From 07a3df6001a261df6d573aeef99e4ee00970f8c6 Mon Sep 17 00:00:00 2001 From: Ben Alkov Date: Sat, 15 Oct 2022 17:09:47 -0400 Subject: [PATCH 01/54] DRAFT: Cross-Attention Control Signed-off-by: Ben Alkov --- c_a_c.py | 177 ++++++++++++++++++++++++++++++++++++++ cross_attention_loop.py | 185 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 362 insertions(+) create mode 100644 c_a_c.py create mode 100644 cross_attention_loop.py diff --git a/c_a_c.py b/c_a_c.py new file mode 100644 index 0000000000..33243edaf7 --- /dev/null +++ b/c_a_c.py @@ -0,0 +1,177 @@ +# Functions supporting Cross-Attention Control +# Copied from https://github.com/bloc97/CrossAttentionControl + +from difflib import SequenceMatcher + +import torch + + +def prompt_token(prompt, index, clip_tokenizer): + tokens = clip_tokenizer(prompt, + padding='max_length', + max_length=clip_tokenizer.model_max_length, + truncation=True, + return_tensors='pt', + return_overflowing_tokens=True + ).input_ids[0] + return clip_tokenizer.decode(tokens[index:index+1]) + + +def init_attention_weights(weight_tuples, clip_tokenizer, unet, device): + tokens_length = clip_tokenizer.model_max_length + weights = torch.ones(tokens_length) + + for i, w in weight_tuples: + if i < tokens_length and i >= 0: + weights[i] = w + + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == 'CrossAttention' and 'attn2' in name: + module.last_attn_slice_weights = weights.to(device) + if module_name == 'CrossAttention' and 'attn1' in name: + module.last_attn_slice_weights = None + + +def init_attention_edit(tokens, tokens_edit, clip_tokenizer, unet, device): + tokens_length = clip_tokenizer.model_max_length + mask = torch.zeros(tokens_length) + indices_target = torch.arange(tokens_length, dtype=torch.long) + indices = torch.zeros(tokens_length, dtype=torch.long) + + tokens = tokens.input_ids.numpy()[0] + tokens_edit = tokens_edit.input_ids.numpy()[0] + + for name, a0, a1, b0, b1 in SequenceMatcher(None, tokens, tokens_edit).get_opcodes(): + if b0 < tokens_length: + if name == 'equal' or (name == 'replace' and a1-a0 == b1-b0): + mask[b0:b1] = 1 + indices[b0:b1] = indices_target[a0:a1] + + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == 'CrossAttention' and 'attn2' in name: + module.last_attn_slice_mask = mask.to(device) + module.last_attn_slice_indices = indices.to(device) + if module_name == 'CrossAttention' and 'attn1' in name: + module.last_attn_slice_mask = None + module.last_attn_slice_indices = None + + +def init_attention_func(unet): + # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 + def new_attention(self, query, key, value): + # TODO: use baddbmm for better performance + attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale + attn_slice = attention_scores.softmax(dim=-1) + # compute attention output + + if self.use_last_attn_slice: + if self.last_attn_slice_mask is not None: + new_attn_slice = (torch.index_select(self.last_attn_slice, -1, + self.last_attn_slice_indices)) + attn_slice = (attn_slice * (1 - self.last_attn_slice_mask) + + new_attn_slice * self.last_attn_slice_mask) + else: + attn_slice = self.last_attn_slice + + self.use_last_attn_slice = False + + if self.save_last_attn_slice: + self.last_attn_slice = attn_slice + self.save_last_attn_slice = False + + if self.use_last_attn_weights and self.last_attn_slice_weights is 
not None: + attn_slice = attn_slice * self.last_attn_slice_weights + self.use_last_attn_weights = False + + hidden_states = torch.matmul(attn_slice, value) + # reshape hidden_states + return self.reshape_batch_dim_to_heads(hidden_states) + + def new_sliced_attention(self, query, key, value, sequence_length, dim): + + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), + device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + attn_slice = ( + torch.matmul(query[start_idx:end_idx], + key[start_idx:end_idx].transpose(1, 2)) * self.scale + ) # TODO: use baddbmm for better performance + attn_slice = attn_slice.softmax(dim=-1) + + if self.use_last_attn_slice: + if self.last_attn_slice_mask is not None: + new_attn_slice = (torch.index_select(self.last_attn_slice, + -1, self.last_attn_slice_indices)) + attn_slice = (attn_slice * (1 - self.last_attn_slice_mask) + + new_attn_slice * self.last_attn_slice_mask) + else: + attn_slice = self.last_attn_slice + + self.use_last_attn_slice = False + + if self.save_last_attn_slice: + self.last_attn_slice = attn_slice + self.save_last_attn_slice = False + + if self.use_last_attn_weights and self.last_attn_slice_weights is not None: + attn_slice = attn_slice * self.last_attn_slice_weights + self.use_last_attn_weights = False + + attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + return self.reshape_batch_dim_to_heads(hidden_states) # reshape hidden_states + + for _, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == 'CrossAttention': + module.last_attn_slice = None + module.use_last_attn_slice = False + module.use_last_attn_weights = False + module.save_last_attn_slice = False + module._sliced_attention = new_sliced_attention.__get__(module, type(module)) + module._attention = new_attention.__get__(module, type(module)) + + +def use_last_tokens_attention(unet, use=True): + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == 'CrossAttention' and 'attn2' in name: + module.use_last_attn_slice = use + + +def use_last_tokens_attention_weights(unet, use=True): + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == 'CrossAttention' and 'attn2' in name: + module.use_last_attn_weights = use + + +def use_last_self_attention(unet, use=True): + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == 'CrossAttention' and 'attn1' in name: + module.use_last_attn_slice = use + + +def save_last_tokens_attention(unet, save=True): + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == 'CrossAttention' and 'attn2' in name: + module.save_last_attn_slice = save + + +def save_last_self_attention(unet, save=True): + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == 'CrossAttention' and 'attn1' in name: + module.save_last_attn_slice = save diff --git a/cross_attention_loop.py b/cross_attention_loop.py new file mode 100644 index 0000000000..e2d6eb6201 --- /dev/null +++ b/cross_attention_loop.py @@ -0,0 +1,185 @@ +import random + +import numpy as np +import torch + +from 
diffusers import (LMSDiscreteScheduler) +from PIL import Image +from torch import autocast +from tqdm.auto import tqdm + +import c_a_c + + +@torch.no_grad() +def stablediffusion( + clip, + clip_tokenizer, + device, + vae, + unet, + prompt='', + prompt_edit=None, + prompt_edit_token_weights=None, + prompt_edit_tokens_start=0.0, + prompt_edit_tokens_end=1.0, + prompt_edit_spatial_start=0.0, + prompt_edit_spatial_end=1.0, + guidance_scale=7.5, + steps=50, + seed=None, + width=512, + height=512, + init_image=None, + init_image_strength=0.5, + ): + if prompt_edit_token_weights is None: + prompt_edit_token_weights = [] + # Change size to multiple of 64 to prevent size mismatches inside model + width = width - width % 64 + height = height - height % 64 + + # If seed is None, randomly select seed from 0 to 2^32-1 + if seed is None: seed = random.randrange(2**32 - 1) + generator = torch.cuda.manual_seed(seed) + + # Set inference timesteps to scheduler + scheduler = LMSDiscreteScheduler(beta_start=0.00085, + beta_end=0.012, + beta_schedule='scaled_linear', + num_train_timesteps=1000, + ) + scheduler.set_timesteps(steps) + + # Preprocess image if it exists (img2img) + if init_image is not None: + # Resize and transpose for numpy b h w c -> torch b c h w + init_image = init_image.resize((width, height), resample=Image.Resampling.LANCZOS) + init_image = np.array(init_image).astype(np.float32) / 255.0 * 2.0 - 1.0 + init_image = torch.from_numpy(init_image[np.newaxis, ...].transpose(0, 3, 1, 2)) + + # If there is alpha channel, composite alpha for white, as the diffusion + # model does not support alpha channel + if init_image.shape[1] > 3: + init_image = init_image[:, :3] * init_image[:, 3:] + (1 - init_image[:, 3:]) + + # Move image to GPU + init_image = init_image.to(device) + + # Encode image + with autocast(device): + init_latent = (vae.encode(init_image) + .latent_dist + .sample(generator=generator) + * 0.18215) + + t_start = steps - int(steps * init_image_strength) + + else: + init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), + device=device) + t_start = 0 + + # Generate random normal noise + noise = torch.randn(init_latent.shape, generator=generator, device=device) + latent = scheduler.add_noise(init_latent, + noise, + torch.tensor([scheduler.timesteps[t_start]], device=device) + ).to(device) + + # Process clip + with autocast(device): + tokens_uncond = clip_tokenizer('', padding='max_length', + max_length=clip_tokenizer.model_max_length, + truncation=True, return_tensors='pt', + return_overflowing_tokens=True + ) + embedding_uncond = clip(tokens_uncond.input_ids.to(device)).last_hidden_state + + tokens_cond = clip_tokenizer(prompt, padding='max_length', + max_length=clip_tokenizer.model_max_length, + truncation=True, return_tensors='pt', + return_overflowing_tokens=True + ) + embedding_cond = clip(tokens_cond.input_ids.to(device)).last_hidden_state + + # Process prompt editing + if prompt_edit is not None: + tokens_cond_edit = clip_tokenizer(prompt_edit, padding='max_length', + max_length=clip_tokenizer.model_max_length, + truncation=True, return_tensors='pt', + return_overflowing_tokens=True + ) + embedding_cond_edit = clip(tokens_cond_edit.input_ids.to(device)).last_hidden_state + + c_a_c.init_attention_edit(tokens_cond, tokens_cond_edit) + + c_a_c.init_attention_func() + c_a_c.init_attention_weights(prompt_edit_token_weights) + + timesteps = scheduler.timesteps[t_start:] + + for idx, timestep in tqdm(enumerate(timesteps), total=len(timesteps)): + t_index = t_start + 
idx + + latent_model_input = latent + latent_model_input = scheduler.scale_model_input(latent_model_input, timestep) + + # Predict the unconditional noise residual + noise_pred_uncond = unet(latent_model_input, + timestep, + encoder_hidden_states=embedding_uncond + ).sample + + # Prepare the Cross-Attention layers + if prompt_edit is not None: + c_a_c.save_last_tokens_attention() + c_a_c.save_last_self_attention() + else: + # Use weights on non-edited prompt when edit is None + c_a_c.use_last_tokens_attention_weights() + + # Predict the conditional noise residual and save the + # cross-attention layer activations + noise_pred_cond = unet(latent_model_input, + timestep, + encoder_hidden_states=embedding_cond + ).sample + + # Edit the Cross-Attention layer activations + if prompt_edit is not None: + t_scale = timestep / scheduler.num_train_timesteps + if (t_scale >= prompt_edit_tokens_start + and t_scale <= prompt_edit_tokens_end): + c_a_c.use_last_tokens_attention() + if (t_scale >= prompt_edit_spatial_start + and t_scale <= prompt_edit_spatial_end): + c_a_c.use_last_self_attention() + + # Use weights on edited prompt + c_a_c.use_last_tokens_attention_weights() + + # Predict the edited conditional noise residual using the + # cross-attention masks + noise_pred_cond = unet(latent_model_input, + timestep, + encoder_hidden_states=embedding_cond_edit + ).sample + + # Perform guidance + noise_pred = (noise_pred_uncond + guidance_scale + * (noise_pred_cond - noise_pred_uncond)) + + latent = scheduler.step(noise_pred, + t_index, + latent + ).prev_sample + + # scale and decode the image latents with vae + latent = latent / 0.18215 + image = vae.decode(latent.to(vae.dtype)).sample + + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy() + image = (image[0] * 255).round().astype('uint8') + return Image.fromarray(image) From b0b19939183c58e2a502aafad4bceddcc33575fe Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Wed, 12 Oct 2022 23:29:48 +0200 Subject: [PATCH 02/54] initial experiments --- ldm/models/diffusion/ksampler.py | 16 ++++++++++++++++ ldm/models/diffusion/sampler.py | 5 +++++ 2 files changed, 21 insertions(+) diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index ac0615b30c..55800d0a5c 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -1,4 +1,6 @@ """wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers""" +from enum import Enum + import k_diffusion as K import torch import torch.nn as nn @@ -25,6 +27,9 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): minval = max(min(-1, scale*minval), -threshold) return torch.clamp(result, min=minval, max=maxval) +class AttentionLayer(Enum): + SELF = 1 + TOKENS = 2 class CFGDenoiser(nn.Module): def __init__(self, model, threshold = 0, warmup = 0): @@ -34,11 +39,22 @@ class CFGDenoiser(nn.Module): self.warmup_max = warmup self.warmup = max(warmup / 10, 1) + + def get_attention_module(self, which: AttentionLayer): + which_attn = "attn1" if which is AttentionLayer.SELF else "attn2" + module = next(module for name,module in self.inner_model.named_modules() if + type(module).__name__ == "CrossAttention" and which_attn in name) + return module + + def forward(self, x, sigma, uncond, cond, cond_scale): x_in = torch.cat([x] * 2) sigma_in = torch.cat([sigma] * 2) cond_in = torch.cat([uncond, cond]) uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + + module = 
self.get_attention_module(AttentionLayer.TOKENS) + if self.warmup < self.warmup_max: thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) self.warmup += 1 diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index ff705513f8..eb7caebba0 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -4,6 +4,8 @@ ldm.models.diffusion.sampler Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc ''' +from enum import Enum + import torch import numpy as np from tqdm import tqdm @@ -411,3 +413,6 @@ class Sampler(object): return self.model.inner_model.q_sample(x0,ts) ''' return self.model.q_sample(x0,ts) + + + From 33d6603fef4839f7627eb817181c1cb59bb3b838 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sun, 16 Oct 2022 16:57:48 +0200 Subject: [PATCH 03/54] cleanup initial experiments --- ldm/models/diffusion/cross_attention.py | 19 +++++++++++++++++++ ldm/models/diffusion/ksampler.py | 10 ---------- 2 files changed, 19 insertions(+), 10 deletions(-) create mode 100644 ldm/models/diffusion/cross_attention.py diff --git a/ldm/models/diffusion/cross_attention.py b/ldm/models/diffusion/cross_attention.py new file mode 100644 index 0000000000..c39d8d5959 --- /dev/null +++ b/ldm/models/diffusion/cross_attention.py @@ -0,0 +1,19 @@ +from enum import Enum + + +class CrossAttention: + + class AttentionType(Enum): + SELF = 1 + TOKENS = 2 + + @classmethod + def get_attention_module(cls, model, which: AttentionType): + which_attn = "attn1" if which is cls.AttentionType.SELF else "attn2" + module = next(module for name, module in model.named_modules() if + type(module).__name__ == "CrossAttention" and which_attn in name) + return module + + @classmethod + def inject_attention_mask_capture(cls, model, callback): + pass \ No newline at end of file diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 55800d0a5c..8010b44d1d 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -27,9 +27,6 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): minval = max(min(-1, scale*minval), -threshold) return torch.clamp(result, min=minval, max=maxval) -class AttentionLayer(Enum): - SELF = 1 - TOKENS = 2 class CFGDenoiser(nn.Module): def __init__(self, model, threshold = 0, warmup = 0): @@ -40,13 +37,6 @@ class CFGDenoiser(nn.Module): self.warmup = max(warmup / 10, 1) - def get_attention_module(self, which: AttentionLayer): - which_attn = "attn1" if which is AttentionLayer.SELF else "attn2" - module = next(module for name,module in self.inner_model.named_modules() if - type(module).__name__ == "CrossAttention" and which_attn in name) - return module - - def forward(self, x, sigma, uncond, cond, cond_scale): x_in = torch.cat([x] * 2) sigma_in = torch.cat([sigma] * 2) From 8ff507b03b9c958ff7c68bc48a5f0e03b5a47706 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sun, 16 Oct 2022 20:39:47 +0200 Subject: [PATCH 04/54] runs but doesn't work properly - see below for test prompt test prompt: "a cat sitting on a car {a dog sitting on a car}" -W 384 -H 256 -s 10 -S 12346 -A k_euler note that substition of dog for cat is currently hard-coded (ksampler.py line 43-44) --- c_a_c.py | 177 ------------------------ cross_attention_loop.py | 5 +- ldm/generate.py | 12 +- ldm/invoke/conditioning.py | 16 ++- ldm/invoke/generator/txt2img.py | 3 +- ldm/models/diffusion/cross_attention.py | 149 +++++++++++++++++++- ldm/models/diffusion/ksampler.py | 33 ++++- 
ldm/modules/attention.py | 11 +- 8 files changed, 207 insertions(+), 199 deletions(-) delete mode 100644 c_a_c.py diff --git a/c_a_c.py b/c_a_c.py deleted file mode 100644 index 33243edaf7..0000000000 --- a/c_a_c.py +++ /dev/null @@ -1,177 +0,0 @@ -# Functions supporting Cross-Attention Control -# Copied from https://github.com/bloc97/CrossAttentionControl - -from difflib import SequenceMatcher - -import torch - - -def prompt_token(prompt, index, clip_tokenizer): - tokens = clip_tokenizer(prompt, - padding='max_length', - max_length=clip_tokenizer.model_max_length, - truncation=True, - return_tensors='pt', - return_overflowing_tokens=True - ).input_ids[0] - return clip_tokenizer.decode(tokens[index:index+1]) - - -def init_attention_weights(weight_tuples, clip_tokenizer, unet, device): - tokens_length = clip_tokenizer.model_max_length - weights = torch.ones(tokens_length) - - for i, w in weight_tuples: - if i < tokens_length and i >= 0: - weights[i] = w - - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == 'CrossAttention' and 'attn2' in name: - module.last_attn_slice_weights = weights.to(device) - if module_name == 'CrossAttention' and 'attn1' in name: - module.last_attn_slice_weights = None - - -def init_attention_edit(tokens, tokens_edit, clip_tokenizer, unet, device): - tokens_length = clip_tokenizer.model_max_length - mask = torch.zeros(tokens_length) - indices_target = torch.arange(tokens_length, dtype=torch.long) - indices = torch.zeros(tokens_length, dtype=torch.long) - - tokens = tokens.input_ids.numpy()[0] - tokens_edit = tokens_edit.input_ids.numpy()[0] - - for name, a0, a1, b0, b1 in SequenceMatcher(None, tokens, tokens_edit).get_opcodes(): - if b0 < tokens_length: - if name == 'equal' or (name == 'replace' and a1-a0 == b1-b0): - mask[b0:b1] = 1 - indices[b0:b1] = indices_target[a0:a1] - - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == 'CrossAttention' and 'attn2' in name: - module.last_attn_slice_mask = mask.to(device) - module.last_attn_slice_indices = indices.to(device) - if module_name == 'CrossAttention' and 'attn1' in name: - module.last_attn_slice_mask = None - module.last_attn_slice_indices = None - - -def init_attention_func(unet): - # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 - def new_attention(self, query, key, value): - # TODO: use baddbmm for better performance - attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale - attn_slice = attention_scores.softmax(dim=-1) - # compute attention output - - if self.use_last_attn_slice: - if self.last_attn_slice_mask is not None: - new_attn_slice = (torch.index_select(self.last_attn_slice, -1, - self.last_attn_slice_indices)) - attn_slice = (attn_slice * (1 - self.last_attn_slice_mask) - + new_attn_slice * self.last_attn_slice_mask) - else: - attn_slice = self.last_attn_slice - - self.use_last_attn_slice = False - - if self.save_last_attn_slice: - self.last_attn_slice = attn_slice - self.save_last_attn_slice = False - - if self.use_last_attn_weights and self.last_attn_slice_weights is not None: - attn_slice = attn_slice * self.last_attn_slice_weights - self.use_last_attn_weights = False - - hidden_states = torch.matmul(attn_slice, value) - # reshape hidden_states - return self.reshape_batch_dim_to_heads(hidden_states) - - def new_sliced_attention(self, query, key, value, sequence_length, dim): - - 
batch_size_attention = query.shape[0] - hidden_states = torch.zeros( - (batch_size_attention, sequence_length, dim // self.heads), - device=query.device, dtype=query.dtype - ) - slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] - for i in range(hidden_states.shape[0] // slice_size): - start_idx = i * slice_size - end_idx = (i + 1) * slice_size - attn_slice = ( - torch.matmul(query[start_idx:end_idx], - key[start_idx:end_idx].transpose(1, 2)) * self.scale - ) # TODO: use baddbmm for better performance - attn_slice = attn_slice.softmax(dim=-1) - - if self.use_last_attn_slice: - if self.last_attn_slice_mask is not None: - new_attn_slice = (torch.index_select(self.last_attn_slice, - -1, self.last_attn_slice_indices)) - attn_slice = (attn_slice * (1 - self.last_attn_slice_mask) - + new_attn_slice * self.last_attn_slice_mask) - else: - attn_slice = self.last_attn_slice - - self.use_last_attn_slice = False - - if self.save_last_attn_slice: - self.last_attn_slice = attn_slice - self.save_last_attn_slice = False - - if self.use_last_attn_weights and self.last_attn_slice_weights is not None: - attn_slice = attn_slice * self.last_attn_slice_weights - self.use_last_attn_weights = False - - attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - - return self.reshape_batch_dim_to_heads(hidden_states) # reshape hidden_states - - for _, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == 'CrossAttention': - module.last_attn_slice = None - module.use_last_attn_slice = False - module.use_last_attn_weights = False - module.save_last_attn_slice = False - module._sliced_attention = new_sliced_attention.__get__(module, type(module)) - module._attention = new_attention.__get__(module, type(module)) - - -def use_last_tokens_attention(unet, use=True): - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == 'CrossAttention' and 'attn2' in name: - module.use_last_attn_slice = use - - -def use_last_tokens_attention_weights(unet, use=True): - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == 'CrossAttention' and 'attn2' in name: - module.use_last_attn_weights = use - - -def use_last_self_attention(unet, use=True): - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == 'CrossAttention' and 'attn1' in name: - module.use_last_attn_slice = use - - -def save_last_tokens_attention(unet, save=True): - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == 'CrossAttention' and 'attn2' in name: - module.save_last_attn_slice = save - - -def save_last_self_attention(unet, save=True): - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == 'CrossAttention' and 'attn1' in name: - module.save_last_attn_slice = save diff --git a/cross_attention_loop.py b/cross_attention_loop.py index e2d6eb6201..ed3e3b0462 100644 --- a/cross_attention_loop.py +++ b/cross_attention_loop.py @@ -1,4 +1,5 @@ import random +import traceback import numpy as np import torch @@ -8,7 +9,7 @@ from PIL import Image from torch import autocast from tqdm.auto import tqdm -import c_a_c +import .ldm.models.diffusion.cross_attention @torch.no_grad() @@ -41,7 +42,7 @@ def stablediffusion( # If seed is None, randomly select seed from 0 to 2^32-1 if seed is None: seed = random.randrange(2**32 - 1) - 
generator = torch.cuda.manual_seed(seed) + generator = torch.manual_seed(seed) # Set inference timesteps to scheduler scheduler = LMSDiscreteScheduler(beta_start=0.00085, diff --git a/ldm/generate.py b/ldm/generate.py index 7fb68dec0a..b8945342b0 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -32,7 +32,7 @@ from ldm.invoke.pngwriter import PngWriter from ldm.invoke.args import metadata_from_png from ldm.invoke.image_util import InitImageResizer from ldm.invoke.devices import choose_torch_device, choose_precision -from ldm.invoke.conditioning import get_uc_and_c +from ldm.invoke.conditioning import get_uc_and_c_and_ec from ldm.invoke.model_cache import ModelCache from ldm.invoke.seamless import configure_model_padding from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale @@ -400,7 +400,7 @@ class Generate: mask_image = None try: - uc, c = get_uc_and_c( + uc, c, ec = get_uc_and_c_and_ec( prompt, model =self.model, skip_normalize=skip_normalize, log_tokens =self.log_tokenization @@ -438,7 +438,7 @@ class Generate: sampler=self.sampler, steps=steps, cfg_scale=cfg_scale, - conditioning=(uc, c), + conditioning=(uc, c, ec), ddim_eta=ddim_eta, image_callback=image_callback, # called after the final image is generated step_callback=step_callback, # called after each intermediate image is generated @@ -469,14 +469,14 @@ class Generate: save_original = save_original, image_callback = image_callback) - except RuntimeError as e: - print(traceback.format_exc(), file=sys.stderr) - print('>> Could not generate image.') except KeyboardInterrupt: if catch_interrupts: print('**Interrupted** Partial results will be returned.') else: raise KeyboardInterrupt + except (RuntimeError, Exception) as e: + print(traceback.format_exc(), file=sys.stderr) + print('>> Could not generate image.') toc = time.time() print('>> Usage stats:') diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index fedd965a2c..1453d9ce8c 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -12,7 +12,7 @@ log_tokenization() print out colour-coded tokens and warn if trunca import re import torch -def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False): +def get_uc_and_c_and_ec(prompt, model, log_tokens=False, skip_normalize=False): # Extract Unconditioned Words From Prompt unconditioned_words = '' unconditional_regex = r'\[(.*?)\]' @@ -26,7 +26,19 @@ def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False): clean_prompt = unconditional_regex_compile.sub(' ', prompt) prompt = re.sub(' +', ' ', clean_prompt) + edited_words = None + edited_regex = r'\{(.*?)\}' + edited = re.findall(edited_regex, prompt) + if len(edited) > 0: + edited_words = ' '.join(edited) + edited_regex_compile = re.compile(edited_regex) + clean_prompt = edited_regex_compile.sub(' ', prompt) + prompt = re.sub(' +', ' ', clean_prompt) + uc = model.get_learned_conditioning([unconditioned_words]) + ec = None + if edited_words is not None: + ec = model.get_learned_conditioning([edited_words]) # get weighted sub-prompts weighted_subprompts = split_weighted_subprompts( @@ -48,7 +60,7 @@ def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False): log_tokenization(prompt, model, log_tokens, 1) c = model.get_learned_conditioning([prompt]) uc = model.get_learned_conditioning([unconditioned_words]) - return (uc, c) + return (uc, c, ec) def split_weighted_subprompts(text, skip_normalize=False)->list: """ diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index 
edd12c948c..23f03f22db 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -19,7 +19,7 @@ class Txt2Img(Generator): kwargs are 'width' and 'height' """ self.perlin = perlin - uc, c = conditioning + uc, c, ec = conditioning @torch.no_grad() def make_image(x_T): @@ -43,6 +43,7 @@ class Txt2Img(Generator): verbose = False, unconditional_guidance_scale = cfg_scale, unconditional_conditioning = uc, + edited_conditioning = ec, eta = ddim_eta, img_callback = step_callback, threshold = threshold, diff --git a/ldm/models/diffusion/cross_attention.py b/ldm/models/diffusion/cross_attention.py index c39d8d5959..a440eb3e6a 100644 --- a/ldm/models/diffusion/cross_attention.py +++ b/ldm/models/diffusion/cross_attention.py @@ -1,8 +1,8 @@ from enum import Enum +import torch -class CrossAttention: - +class CrossAttentionControl: class AttentionType(Enum): SELF = 1 TOKENS = 2 @@ -15,5 +15,146 @@ class CrossAttention: return module @classmethod - def inject_attention_mask_capture(cls, model, callback): - pass \ No newline at end of file + def setup_attention_editing(cls, model, original_tokens_length: int, + substitute_conditioning: torch.Tensor = None, + token_indices_to_edit: list = None): + + # adapted from init_attention_edit + self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF) + tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS) + + if substitute_conditioning is not None: + + device = substitute_conditioning.device + + # this is not very torch-y + mask = torch.zeros(original_tokens_length) + for i in token_indices_to_edit: + mask[i] = 1 + + self_attention_module.last_attn_slice_mask = None + self_attention_module.last_attn_slice_indices = None + tokens_attention_module.last_attn_slice_mask = mask.to(device) + tokens_attention_module.last_attn_slice_indices = torch.tensor(token_indices_to_edit, device=device) + + cls.inject_attention_functions(model) + + @classmethod + def request_save_attention_maps(cls, model): + self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF) + tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS) + self_attention_module.save_last_attn_slice = True + tokens_attention_module.save_last_attn_slice = True + + @classmethod + def request_apply_saved_attention_maps(cls, model): + self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF) + tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS) + self_attention_module.use_last_attn_slice = True + tokens_attention_module.use_last_attn_slice = True + + @classmethod + def inject_attention_functions(cls, unet): + # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 + def new_attention(self, query, key, value): + # TODO: use baddbmm for better performance + print(f"entered new_attention") + attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale + attn_slice = attention_scores.softmax(dim=-1) + # compute attention output + + if self.use_last_attn_slice: + if self.last_attn_slice_mask is not None: + print('using masked last_attn_slice') + + new_attn_slice = (torch.index_select(self.last_attn_slice, -1, + self.last_attn_slice_indices)) + attn_slice = (attn_slice * (1 - self.last_attn_slice_mask) + + new_attn_slice * self.last_attn_slice_mask) + else: + print('using unmasked last_attn_slice') + attn_slice = self.last_attn_slice + + 
self.use_last_attn_slice = False + else: + print('not using last_attn_slice') + + if self.save_last_attn_slice: + print('saving last_attn_slice') + self.last_attn_slice = attn_slice + self.save_last_attn_slice = False + else: + print('not saving last_attn_slice') + + if self.use_last_attn_weights and self.last_attn_slice_weights is not None: + attn_slice = attn_slice * self.last_attn_slice_weights + self.use_last_attn_weights = False + + hidden_states = torch.matmul(attn_slice, value) + # reshape hidden_states + return hidden_states + + for _, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == 'CrossAttention': + module.last_attn_slice = None + module.use_last_attn_slice = False + module.use_last_attn_weights = False + module.save_last_attn_slice = False + module.cross_attention_callback = new_attention.__get__(module, type(module)) + + +# original code below + +# Functions supporting Cross-Attention Control +# Copied from https://github.com/bloc97/CrossAttentionControl + +from difflib import SequenceMatcher + +import torch + + +def prompt_token(prompt, index, clip_tokenizer): + tokens = clip_tokenizer(prompt, + padding='max_length', + max_length=clip_tokenizer.model_max_length, + truncation=True, + return_tensors='pt', + return_overflowing_tokens=True + ).input_ids[0] + return clip_tokenizer.decode(tokens[index:index + 1]) + + +def use_last_tokens_attention(unet, use=True): + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == 'CrossAttention' and 'attn2' in name: + module.use_last_attn_slice = use + + +def use_last_tokens_attention_weights(unet, use=True): + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == 'CrossAttention' and 'attn2' in name: + module.use_last_attn_weights = use + + +def use_last_self_attention(unet, use=True): + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == 'CrossAttention' and 'attn1' in name: + module.use_last_attn_slice = use + + +def save_last_tokens_attention(unet, save=True): + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == 'CrossAttention' and 'attn2' in name: + module.save_last_attn_slice = save + + +def save_last_self_attention(unet, save=True): + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == 'CrossAttention' and 'attn1' in name: + module.save_last_attn_slice = save diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 8010b44d1d..29949aff8d 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -13,6 +13,7 @@ from ldm.modules.diffusionmodules.util import ( noise_like, extract_into_tensor, ) +from ldm.models.diffusion.cross_attention import CrossAttentionControl def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): if threshold <= 0.0: @@ -29,21 +30,41 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): class CFGDenoiser(nn.Module): - def __init__(self, model, threshold = 0, warmup = 0): + def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None): super().__init__() self.inner_model = model self.threshold = threshold self.warmup_max = warmup self.warmup = max(warmup / 10, 1) + self.edited_conditioning = edited_conditioning + + if self.edited_conditioning is not None: + initial_tokens_count = 77 # ' a cat sitting on a car ' + token_indices_to_edit = [2] # 
'cat' + CrossAttentionControl.setup_attention_editing(self.inner_model, initial_tokens_count, edited_conditioning, token_indices_to_edit) def forward(self, x, sigma, uncond, cond, cond_scale): x_in = torch.cat([x] * 2) sigma_in = torch.cat([sigma] * 2) cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) - module = self.get_attention_module(AttentionLayer.TOKENS) + print('generating new unconditioned latents') + unconditioned_latents = self.inner_model(x, sigma, cond=uncond) + + # process x using the original prompt, saving the attention maps if required + if self.edited_conditioning is not None: + # this is automatically toggled off after the model forward() + CrossAttentionControl.request_save_attention_maps(self.inner_model) + print('generating new conditioned latents') + conditioned_latents = self.inner_model(x, sigma, cond=cond) + + if self.edited_conditioning is not None: + # process x again, using the saved attention maps but the new conditioning + # this is automatically toggled off after the model forward() + CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model) + print('generating edited conditioned latents') + conditioned_latents = self.inner_model(x, sigma, cond=self.edited_conditioning) if self.warmup < self.warmup_max: thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) @@ -52,7 +73,8 @@ class CFGDenoiser(nn.Module): thresh = self.threshold if thresh > self.threshold: thresh = self.threshold - return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh) + delta = (conditioned_latents - unconditioned_latents) + return cfg_apply_threshold(unconditioned_latents + delta * cond_scale, thresh) class KSampler(Sampler): @@ -169,6 +191,7 @@ class KSampler(Sampler): log_every_t=100, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + edited_conditioning=None, threshold = 0, perlin = 0, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
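
[Aside] The three inner_model() calls in CFGDenoiser.forward() above implement the bloc97 save-then-reapply flow: the base prompt is run while each CrossAttention module stores its attention probabilities, then the edited prompt is run while those stored probabilities are substituted back in. Below is a minimal, self-contained Python sketch of that flow; SimpleCrossAttention and its tensors are illustrative stand-ins (only the flag names last_attn_slice / save_last_attn_slice / use_last_attn_slice come from the patch), and the masked per-token substitution path is omitted.

import torch
import torch.nn as nn

class SimpleCrossAttention(nn.Module):
    # Toy stand-in for the CrossAttention modules the patch monkey-patches.
    def __init__(self, dim=8, context_dim=8):
        super().__init__()
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(context_dim, dim, bias=False)
        self.to_v = nn.Linear(context_dim, dim, bias=False)
        self.scale = dim ** -0.5
        # flags mirroring the patch
        self.last_attn_slice = None
        self.save_last_attn_slice = False
        self.use_last_attn_slice = False

    def forward(self, x, context):
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
        attn = (q @ k.transpose(-1, -2) * self.scale).softmax(dim=-1)
        if self.use_last_attn_slice:
            # substitute the attention maps saved during the base-prompt pass
            # (the patch also supports a mask/indices path for partial substitution)
            attn = self.last_attn_slice
            self.use_last_attn_slice = False
        if self.save_last_attn_slice:
            self.last_attn_slice = attn
            self.save_last_attn_slice = False
        return attn @ v

attn_layer = SimpleCrossAttention()
x = torch.randn(1, 4, 8)             # latent tokens
cond = torch.randn(1, 77, 8)         # base-prompt embedding
cond_edited = torch.randn(1, 77, 8)  # edited-prompt embedding

attn_layer.save_last_attn_slice = True
out_base = attn_layer(x, cond)           # pass 1: save attention maps

attn_layer.use_last_attn_slice = True
out_edited = attn_layer(x, cond_edited)  # pass 2: edited values, saved maps
[End aside]
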
@@ -200,7 +223,7 @@ class KSampler(Sampler): else: x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] - model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10)) + model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10), edited_conditioning=edited_conditioning) extra_args = { 'cond': conditioning, 'uncond': unconditional_conditioning, diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index ef9c2d3e65..a9805e6c67 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -170,6 +170,8 @@ class CrossAttention(nn.Module): self.mem_total_gb = psutil.virtual_memory().total // (1 << 30) + self.cross_attention_callback = None + def einsum_op_compvis(self, q, k, v): s = einsum('b i d, b j d -> b i j', q, k) s = s.softmax(dim=-1, dtype=s.dtype) @@ -244,8 +246,13 @@ class CrossAttention(nn.Module): del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - r = self.einsum_op(q, k, v) - return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) + + if self.cross_attention_callback is not None: + r = self.cross_attention_callback(q, k, v) + else: + r = self.einsum_op(q, k, v) + hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h) + return self.to_out(hidden_states) class BasicTransformerBlock(nn.Module): From 1fc1f8bf05c01af0c07714df38f10745dd984103 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Mon, 17 Oct 2022 21:15:03 +0200 Subject: [PATCH 05/54] cross-attention working with placeholder {} syntax --- ldm/generate.py | 4 +- ldm/invoke/conditioning.py | 34 +- ldm/invoke/generator/txt2img.py | 3 +- ldm/models/diffusion/cross_attention.py | 141 ++++-- ldm/models/diffusion/ksampler.py | 23 +- ldm/modules/attention.py | 554 +++++++++++++++++------- ldm/modules/diffusionmodules/model.py | 10 +- ldm/modules/encoders/modules.py | 2 +- 8 files changed, 534 insertions(+), 237 deletions(-) diff --git a/ldm/generate.py b/ldm/generate.py index b8945342b0..37df973291 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -400,7 +400,7 @@ class Generate: mask_image = None try: - uc, c, ec = get_uc_and_c_and_ec( + uc, c, ec, ec_index_map = get_uc_and_c_and_ec( prompt, model =self.model, skip_normalize=skip_normalize, log_tokens =self.log_tokenization @@ -438,7 +438,7 @@ class Generate: sampler=self.sampler, steps=steps, cfg_scale=cfg_scale, - conditioning=(uc, c, ec), + conditioning=(uc, c, ec, ec_index_map), ddim_eta=ddim_eta, image_callback=image_callback, # called after the final image is generated step_callback=step_callback, # called after each intermediate image is generated diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index 1453d9ce8c..8c8f5eeb01 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -10,6 +10,8 @@ log_tokenization() print out colour-coded tokens and warn if trunca ''' import re +from difflib import SequenceMatcher + import torch def get_uc_and_c_and_ec(prompt, model, log_tokens=False, skip_normalize=False): @@ -35,32 +37,46 @@ def get_uc_and_c_and_ec(prompt, model, log_tokens=False, skip_normalize=False): clean_prompt = edited_regex_compile.sub(' ', prompt) prompt = re.sub(' +', ' ', clean_prompt) - uc = model.get_learned_conditioning([unconditioned_words]) - ec = None - if edited_words is not None: - ec = model.get_learned_conditioning([edited_words]) - # get weighted sub-prompts weighted_subprompts = split_weighted_subprompts( prompt, skip_normalize ) + ec = None + edit_opcodes = None + + 
uc, _ = model.get_learned_conditioning([unconditioned_words]) + if len(weighted_subprompts) > 1: # i dont know if this is correct.. but it works c = torch.zeros_like(uc) # normalize each "sub prompt" and add it for subprompt, weight in weighted_subprompts: log_tokenization(subprompt, model, log_tokens, weight) + subprompt_embeddings, _ = model.get_learned_conditioning([subprompt]) c = torch.add( c, - model.get_learned_conditioning([subprompt]), + subprompt_embeddings, alpha=weight, ) + if edited_words is not None: + print("can't do cross-attention control with blends just yet, ignoring edits") else: # just standard 1 prompt log_tokenization(prompt, model, log_tokens, 1) - c = model.get_learned_conditioning([prompt]) - uc = model.get_learned_conditioning([unconditioned_words]) - return (uc, c, ec) + c, c_tokens = model.get_learned_conditioning([prompt]) + if edited_words is not None: + ec, ec_tokens = model.get_learned_conditioning([edited_words]) + edit_opcodes = build_token_edit_opcodes(c_tokens, ec_tokens) + + return (uc, c, ec, edit_opcodes) + +def build_token_edit_opcodes(c_tokens, ec_tokens): + tokens = c_tokens.cpu().numpy()[0] + tokens_edit = ec_tokens.cpu().numpy()[0] + + opcodes = SequenceMatcher(None, tokens, tokens_edit).get_opcodes() + return opcodes + def split_weighted_subprompts(text, skip_normalize=False)->list: """ diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index 23f03f22db..9f066745f7 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -19,7 +19,7 @@ class Txt2Img(Generator): kwargs are 'width' and 'height' """ self.perlin = perlin - uc, c, ec = conditioning + uc, c, ec, edit_index_map = conditioning @torch.no_grad() def make_image(x_T): @@ -44,6 +44,7 @@ class Txt2Img(Generator): unconditional_guidance_scale = cfg_scale, unconditional_conditioning = uc, edited_conditioning = ec, + edit_token_index_map = edit_index_map, eta = ddim_eta, img_callback = step_callback, threshold = threshold, diff --git a/ldm/models/diffusion/cross_attention.py b/ldm/models/diffusion/cross_attention.py index a440eb3e6a..d829162f35 100644 --- a/ldm/models/diffusion/cross_attention.py +++ b/ldm/models/diffusion/cross_attention.py @@ -2,89 +2,99 @@ from enum import Enum import torch +# adapted from bloc97's CrossAttentionControl colab +# https://github.com/bloc97/CrossAttentionControl + class CrossAttentionControl: class AttentionType(Enum): SELF = 1 TOKENS = 2 @classmethod - def get_attention_module(cls, model, which: AttentionType): - which_attn = "attn1" if which is cls.AttentionType.SELF else "attn2" - module = next(module for name, module in model.named_modules() if - type(module).__name__ == "CrossAttention" and which_attn in name) - return module - - @classmethod - def setup_attention_editing(cls, model, original_tokens_length: int, + def setup_attention_editing(cls, model, substitute_conditioning: torch.Tensor = None, - token_indices_to_edit: list = None): + edit_opcodes: list = None): + """ + :param model: The unet model to inject into. + :param substitute_conditioning: The "edited" conditioning vector, [Bx77x768] + :param edit_opcodes: Opcodes from difflib.SequenceMatcher describing how the base + conditionings map to the "edited" conditionings. 
+ :return: + """ # adapted from init_attention_edit - self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF) - tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS) - if substitute_conditioning is not None: device = substitute_conditioning.device - # this is not very torch-y - mask = torch.zeros(original_tokens_length) - for i in token_indices_to_edit: - mask[i] = 1 + max_length = model.inner_model.cond_stage_model.max_length + # mask=1 means use base prompt attention, mask=0 means use edited prompt attention + mask = torch.zeros(max_length) + indices_target = torch.arange(max_length, dtype=torch.long) + indices = torch.zeros(max_length, dtype=torch.long) + for name, a0, a1, b0, b1 in edit_opcodes: + if b0 < max_length: + if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): + # these tokens have not been edited + indices[b0:b1] = indices_target[a0:a1] + mask[b0:b1] = 1 - self_attention_module.last_attn_slice_mask = None - self_attention_module.last_attn_slice_indices = None - tokens_attention_module.last_attn_slice_mask = mask.to(device) - tokens_attention_module.last_attn_slice_indices = torch.tensor(token_indices_to_edit, device=device) + for m in cls.get_attention_modules(model, cls.AttentionType.SELF): + m.last_attn_slice_mask = None + m.last_attn_slice_indices = None + + for m in cls.get_attention_modules(model, cls.AttentionType.TOKENS): + m.last_attn_slice_mask = mask.to(device) + m.last_attn_slice_indices = indices.to(device) cls.inject_attention_functions(model) + + @classmethod + def get_attention_modules(cls, model, which: AttentionType): + which_attn = "attn1" if which is cls.AttentionType.SELF else "attn2" + return [module for name, module in model.named_modules() if + type(module).__name__ == "CrossAttention" and which_attn in name] + + @classmethod def request_save_attention_maps(cls, model): - self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF) - tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS) - self_attention_module.save_last_attn_slice = True - tokens_attention_module.save_last_attn_slice = True + self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF) + tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS) + for m in self_attention_modules+tokens_attention_modules: + m.save_last_attn_slice = True @classmethod def request_apply_saved_attention_maps(cls, model): - self_attention_module = cls.get_attention_module(model, cls.AttentionType.SELF) - tokens_attention_module = cls.get_attention_module(model, cls.AttentionType.TOKENS) - self_attention_module.use_last_attn_slice = True - tokens_attention_module.use_last_attn_slice = True + self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF) + tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS) + for m in self_attention_modules+tokens_attention_modules: + m.use_last_attn_slice = True + @classmethod def inject_attention_functions(cls, unet): # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 def new_attention(self, query, key, value): # TODO: use baddbmm for better performance - print(f"entered new_attention") attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale attn_slice = attention_scores.softmax(dim=-1) # compute attention output if self.use_last_attn_slice: 
if self.last_attn_slice_mask is not None: - print('using masked last_attn_slice') - - new_attn_slice = (torch.index_select(self.last_attn_slice, -1, - self.last_attn_slice_indices)) - attn_slice = (attn_slice * (1 - self.last_attn_slice_mask) - + new_attn_slice * self.last_attn_slice_mask) + base_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices) + base_attn_slice_mask = self.last_attn_slice_mask + this_attn_slice_mask = 1 - self.last_attn_slice_mask + attn_slice = attn_slice * this_attn_slice_mask + base_attn_slice * base_attn_slice_mask else: - print('using unmasked last_attn_slice') attn_slice = self.last_attn_slice self.use_last_attn_slice = False - else: - print('not using last_attn_slice') if self.save_last_attn_slice: - print('saving last_attn_slice') self.last_attn_slice = attn_slice self.save_last_attn_slice = False - else: - print('not saving last_attn_slice') if self.use_last_attn_weights and self.last_attn_slice_weights is not None: attn_slice = attn_slice * self.last_attn_slice_weights @@ -92,16 +102,59 @@ class CrossAttentionControl: hidden_states = torch.matmul(attn_slice, value) # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states - for _, module in unet.named_modules(): + def new_sliced_attention(self, query, key, value, sequence_length, dim): + + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + attn_slice = ( + torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale + ) # TODO: use baddbmm for better performance + attn_slice = attn_slice.softmax(dim=-1) + + if self.use_last_attn_slice: + if self.last_attn_slice_mask is not None: + new_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices) + attn_slice = attn_slice * ( + 1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask + else: + attn_slice = self.last_attn_slice + + self.use_last_attn_slice = False + + if self.save_last_attn_slice: + self.last_attn_slice = attn_slice + self.save_last_attn_slice = False + + if self.use_last_attn_weights and self.last_attn_slice_weights is not None: + attn_slice = attn_slice * self.last_attn_slice_weights + self.use_last_attn_weights = False + + attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + for name, module in unet.named_modules(): module_name = type(module).__name__ - if module_name == 'CrossAttention': + if module_name == "CrossAttention": module.last_attn_slice = None module.use_last_attn_slice = False module.use_last_attn_weights = False module.save_last_attn_slice = False - module.cross_attention_callback = new_attention.__get__(module, type(module)) + module._sliced_attention = new_sliced_attention.__get__(module, type(module)) + module._attention = new_attention.__get__(module, type(module)) # original code below diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 29949aff8d..e5d521f33f 100644 --- a/ldm/models/diffusion/ksampler.py +++ 
b/ldm/models/diffusion/ksampler.py @@ -30,7 +30,7 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): class CFGDenoiser(nn.Module): - def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None): + def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None, edit_opcodes = None): super().__init__() self.inner_model = model self.threshold = threshold @@ -39,24 +39,23 @@ class CFGDenoiser(nn.Module): self.edited_conditioning = edited_conditioning - if self.edited_conditioning is not None: - initial_tokens_count = 77 # ' a cat sitting on a car ' - token_indices_to_edit = [2] # 'cat' - CrossAttentionControl.setup_attention_editing(self.inner_model, initial_tokens_count, edited_conditioning, token_indices_to_edit) + if edited_conditioning is not None: + # a cat sitting on a car + CrossAttentionControl.setup_attention_editing(self.inner_model, edited_conditioning, edit_opcodes) + else: + # pass through the attention func but don't act on it + CrossAttentionControl.setup_attention_editing(self.inner_model) def forward(self, x, sigma, uncond, cond, cond_scale): - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - print('generating new unconditioned latents') + print('generating unconditioned latents') unconditioned_latents = self.inner_model(x, sigma, cond=uncond) # process x using the original prompt, saving the attention maps if required if self.edited_conditioning is not None: # this is automatically toggled off after the model forward() CrossAttentionControl.request_save_attention_maps(self.inner_model) - print('generating new conditioned latents') + print('generating conditioned latents') conditioned_latents = self.inner_model(x, sigma, cond=cond) if self.edited_conditioning is not None: @@ -192,6 +191,7 @@ class KSampler(Sampler): unconditional_guidance_scale=1.0, unconditional_conditioning=None, edited_conditioning=None, + edit_token_index_map=None, threshold = 0, perlin = 0, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... @@ -223,7 +223,8 @@ class KSampler(Sampler): else: x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] - model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10), edited_conditioning=edited_conditioning) + model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10), + edited_conditioning=edited_conditioning, edit_opcodes=edit_token_index_map) extra_args = { 'cond': conditioning, 'uncond': unconditional_conditioning, diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index a9805e6c67..d00b95b1af 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -1,155 +1,367 @@ -from inspect import isfunction +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import math +from typing import Optional + import torch import torch.nn.functional as F -from torch import nn, einsum -from einops import rearrange, repeat - -from ldm.modules.diffusionmodules.util import checkpoint - -import psutil - -def exists(val): - return val is not None +from torch import nn -def uniq(arr): - return{el: True for el in arr}.keys() +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted + to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + Uses three q, k, v linear layers to compute attention. + + Parameters: + channels (:obj:`int`): The number of channels in the input and output. + num_head_channels (:obj:`int`, *optional*): + The number of channels in each head. If None, then `num_heads` = 1. + num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. + rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by. + eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. + """ + + def __init__( + self, + channels: int, + num_head_channels: Optional[int] = None, + num_groups: int = 32, + rescale_output_factor: float = 1.0, + eps: float = 1e-5, + ): + super().__init__() + self.channels = channels + + self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 + self.num_head_size = num_head_channels + self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) + + # define q,k,v as linear layers + self.query = nn.Linear(channels, channels) + self.key = nn.Linear(channels, channels) + self.value = nn.Linear(channels, channels) + + self.rescale_output_factor = rescale_output_factor + self.proj_attn = nn.Linear(channels, channels, 1) + + def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: + new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) + # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) + new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) + return new_projection + + def forward(self, hidden_states): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + # transpose + query_states = self.transpose_for_scores(query_proj) + key_states = self.transpose_for_scores(key_proj) + value_states = self.transpose_for_scores(value_proj) + + # get scores + scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) + attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm + attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) + + # compute attention output + hidden_states = torch.matmul(attention_probs, value_states) + + hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() + new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) + hidden_states = hidden_states.view(new_hidden_states_shape) + + # compute next hidden_states + hidden_states = self.proj_attn(hidden_states) + hidden_states = 
hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply + standard transformer action. Finally, reshape to image. + + Parameters: + in_channels (:obj:`int`): The number of channels in the input and output. + n_heads (:obj:`int`): The number of heads to use for multi-head attention. + d_head (:obj:`int`): The number of channels in each head. + depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use. + context_dim (:obj:`int`, *optional*): The number of context dimensions to use. + """ + + def __init__( + self, + in_channels: int, + n_heads: int, + d_head: int, + depth: int = 1, + dropout: float = 0.0, + num_groups: int = 32, + context_dim: Optional[int] = None, + ): + super().__init__() + self.n_heads = n_heads + self.d_head = d_head + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) + for d in range(depth) + ] + ) + + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def _set_attention_slice(self, slice_size): + for block in self.transformer_blocks: + block._set_attention_slice(slice_size) + + def forward(self, hidden_states, context=None): + # note: if no context is given, cross-attention defaults to self-attention + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + hidden_states = self.norm(hidden_states) + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + for block in self.transformer_blocks: + hidden_states = block(hidden_states, context=context) + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2) + hidden_states = self.proj_out(hidden_states) + return hidden_states + residual -def max_neg_value(t): - return -torch.finfo(t.dtype).max +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (:obj:`int`): The number of channels in the input and output. + n_heads (:obj:`int`): The number of heads to use for multi-head attention. + d_head (:obj:`int`): The number of channels in each head. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. + gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network. + checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing. 
+ """ + + def __init__( + self, + dim: int, + n_heads: int, + d_head: int, + dropout=0.0, + context_dim: Optional[int] = None, + gated_ff: bool = True, + checkpoint: bool = True, + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def _set_attention_slice(self, slice_size): + self.attn1._slice_size = slice_size + self.attn2._slice_size = slice_size + + def forward(self, hidden_states, context=None): + hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states + hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states + hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + return hidden_states -def init_(tensor): - dim = tensor.shape[-1] - std = 1 / math.sqrt(dim) - tensor.uniform_(-std, std) - return tensor +class CrossAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (:obj:`int`): The number of channels in the query. + context_dim (:obj:`int`, *optional*): + The number of channels in the context. If not given, defaults to `query_dim`. + heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. 
+ """ + + def __init__( + self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0 + ): + super().__init__() + inner_dim = dim_head * heads + context_dim = context_dim if context_dim is not None else query_dim + + self.scale = dim_head**-0.5 + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self._slice_size = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def forward(self, hidden_states, context=None, mask=None): + batch_size, sequence_length, _ = hidden_states.shape + + query = self.to_q(hidden_states) + context = context if context is not None else hidden_states + key = self.to_k(context) + value = self.to_v(context) + + dim = query.shape[-1] + + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + # TODO(PVP) - mask is currently never used. 
Remember to re-implement when used + + # attention, what we cannot get enough of + + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim) + + return self.to_out(hidden_states) + + def _attention(self, query, key, value): + # TODO: use baddbmm for better performance + attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale + attention_probs = attention_scores.softmax(dim=-1) + # compute attention output + hidden_states = torch.matmul(attention_probs, value) + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + attn_slice = ( + torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale + ) # TODO: use baddbmm for better performance + attn_slice = attn_slice.softmax(dim=-1) + attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (:obj:`int`): The number of channels in the input. + dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation. + dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. + """ + + def __init__( + self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0 + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + project_in = GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, hidden_states): + return self.net(hidden_states) # feedforward class GEGLU(nn.Module): - def __init__(self, dim_in, dim_out): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (:obj:`int`): The number of channels in the input. + dim_out (:obj:`int`): The number of channels in the output. 
+ """ + + def __init__(self, dim_in: int, dim_out: int): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) - def forward(self, x): - x, gate = self.proj(x).chunk(2, dim=-1) - return x * F.gelu(gate) - - -class FeedForward(nn.Module): - def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): - super().__init__() - inner_dim = int(dim * mult) - dim_out = default(dim_out, dim) - project_in = nn.Sequential( - nn.Linear(dim, inner_dim), - nn.GELU() - ) if not glu else GEGLU(dim, inner_dim) - - self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out) - ) - - def forward(self, x): - return self.net(x) - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def Normalize(in_channels): - return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) - - -class LinearAttention(nn.Module): - def __init__(self, dim, heads=4, dim_head=32): - super().__init__() - self.heads = heads - hidden_dim = dim_head * heads - self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) - self.to_out = nn.Conv2d(hidden_dim, dim, 1) - - def forward(self, x): - b, c, h, w = x.shape - qkv = self.to_qkv(x) - q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) - k = k.softmax(dim=-1) - context = torch.einsum('bhdn,bhen->bhde', k, v) - out = torch.einsum('bhde,bhdn->bhen', context, q) - out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) - return self.to_out(out) - - -class SpatialSelfAttention(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels - - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) - - # compute attention - b,c,h,w = q.shape - q = rearrange(q, 'b c h w -> b (h w) c') - k = rearrange(k, 'b c h w -> b c (h w)') - w_ = torch.einsum('bij,bjk->bik', q, k) - - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = rearrange(v, 'b c h w -> b c (h w)') - w_ = rearrange(w_, 'b i j -> b j i') - h_ = torch.einsum('bij,bjk->bik', v, w_) - h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) - h_ = self.proj_out(h_) - - return x+h_ - - + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * F.gelu(gate) +''' class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): super().__init__() @@ -172,48 +384,45 @@ class CrossAttention(nn.Module): self.cross_attention_callback = None - def einsum_op_compvis(self, q, k, v): - s = einsum('b i d, b j d -> b i j', q, k) - s = s.softmax(dim=-1, dtype=s.dtype) - return einsum('b i j, b j d -> b i d', s, v) - - def einsum_op_slice_0(self, q, k, v, slice_size): + def einsum_op_slice_dim0(self, q, k, v, slice_size, callback): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for 
i in range(0, q.shape[0], slice_size): end = i + slice_size - r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end]) + r[i:end] = callback(q[i:end], k[i:end], v[i:end], offset=i) return r - def einsum_op_slice_1(self, q, k, v, slice_size): + def einsum_op_slice_dim1(self, q, k, v, slice_size, callback): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[1], slice_size): end = i + slice_size - r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v) + r[:, i:end] = callback(q[:, i:end], k, v, offset=i) return r - def einsum_op_mps_v1(self, q, k, v): + def einsum_op_mps_v1(self, q, k, v, callback): if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 - return self.einsum_op_compvis(q, k, v) + return callback(q, k, v) else: slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) - return self.einsum_op_slice_1(q, k, v, slice_size) + return self.einsum_op_slice_dim1(q, k, v, slice_size, callback) - def einsum_op_mps_v2(self, q, k, v): + def einsum_op_mps_v2(self, q, k, v, callback): if self.mem_total_gb > 8 and q.shape[1] <= 4096: - return self.einsum_op_compvis(q, k, v) + return callback(q, k, v, offset=0) else: - return self.einsum_op_slice_0(q, k, v, 1) + return self.einsum_op_slice_dim0(q, k, v, 1, callback) - def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb): + def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb, callback): size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) if size_mb <= max_tensor_mb: - return self.einsum_op_compvis(q, k, v) + return callback(q, k, v, offset=0) div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() if div <= q.shape[0]: - return self.einsum_op_slice_0(q, k, v, q.shape[0] // div) - return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) + print("warning: untested call to einsum_op_slice_dim0") + return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div, callback) + print("warning: untested call to einsum_op_slice_dim1") + return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1), callback) - def einsum_op_cuda(self, q, k, v): + def einsum_op_cuda(self, q, k, v, callback): stats = torch.cuda.memory_stats(q.device) mem_active = stats['active_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current'] @@ -221,20 +430,26 @@ class CrossAttention(nn.Module): mem_free_torch = mem_reserved - mem_active mem_free_total = mem_free_cuda + mem_free_torch # Divide factor of safety as there's copying and fragmentation - return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) + return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20), callback) + + def get_attention_mem_efficient(self, q, k, v, callback): + """ + Calculate attention by slicing q, k, and v for memory efficiency then calling + callback(q, k, v, offset=offset) + multiple times if necessary. The offset argument is something + """ - def einsum_op(self, q, k, v): if q.device.type == 'cuda': - return self.einsum_op_cuda(q, k, v) + return self.einsum_op_cuda(q, k, v, callback) if q.device.type == 'mps': if self.mem_total_gb >= 32: - return self.einsum_op_mps_v1(q, k, v) - return self.einsum_op_mps_v2(q, k, v) + return self.einsum_op_mps_v1(q, k, v, callback) + return self.einsum_op_mps_v2(q, k, v, callback) # Smaller slices are faster due to L2/L3/SLC caches. # Tested on i7 with 8MB L3 cache. 
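# --- Illustrative aside (not part of this patch): at this stage of the series
# the slicing helpers above hand each q/k/v slice to `callback(q, k, v, offset=...)`,
# where `offset` gives the start of the slice along the sliced dimension. A
# minimal callback compatible with that contract might look like the sketch
# below; the `storage` dict and function names are assumptions for illustration.

from torch import einsum

def make_saving_calculator(storage):
    def calculator(q, k, v, offset=0):
        # same math as the default path: scores, softmax, then weighted values
        scores = einsum('b i d, b j d -> b i j', q, k)
        attn = scores.softmax(dim=-1, dtype=scores.dtype)
        storage[offset] = attn.detach()   # stash the per-slice attention map
        return einsum('b i j, b j d -> b i d', attn, v)
    return calculator
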
- return self.einsum_op_tensor_mem(q, k, v, 32) + return self.einsum_op_tensor_mem(q, k, v, 32, callback) def forward(self, x, context=None, mask=None): h = self.heads @@ -247,14 +462,24 @@ class CrossAttention(nn.Module): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - if self.cross_attention_callback is not None: - r = self.cross_attention_callback(q, k, v) - else: - r = self.einsum_op(q, k, v) + def default_attention_calculator(q, k, v, **kwargs): + # calculate attention scores + attention_scores = einsum('b i d, b j d -> b i j', q, k) + # calculate attenion slice by taking the best scores for each latent pixel + attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype) + return einsum('b i j, b j d -> b i d', attention_slice, v) + + attention_calculator = \ + self.custom_attention_calculator if self.custom_attention_calculator is not None \ + else default_attention_calculator + + r = self.get_attention_mem_efficient(q, k, v, attention_calculator) + hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h) return self.to_out(hidden_states) + class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): super().__init__() @@ -322,3 +547,4 @@ class SpatialTransformer(nn.Module): x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) x = self.proj_out(x) return x + x_in +''' \ No newline at end of file diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 739710d006..73218d36f8 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -8,7 +8,7 @@ import numpy as np from einops import rearrange from ldm.util import instantiate_from_config -from ldm.modules.attention import LinearAttention +#from ldm.modules.attention import LinearAttention import psutil @@ -151,10 +151,10 @@ class ResnetBlock(nn.Module): return x + h -class LinAttnBlock(LinearAttention): - """to match AttnBlock usage""" - def __init__(self, in_channels): - super().__init__(dim=in_channels, heads=1, dim_head=in_channels) +#class LinAttnBlock(LinearAttention): +# """to match AttnBlock usage""" +# def __init__(self, in_channels): +# super().__init__(dim=in_channels, heads=1, dim_head=in_channels) class AttnBlock(nn.Module): diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 426fccced3..12ef737134 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -449,7 +449,7 @@ class FrozenCLIPEmbedder(AbstractEncoder): tokens = batch_encoding['input_ids'].to(self.device) z = self.transformer(input_ids=tokens, **kwargs) - return z + return z, tokens def encode(self, text, **kwargs): return self(text, **kwargs) From 37a204324b85b85debdada6216a21ce9735e673f Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Tue, 18 Oct 2022 01:54:30 +0200 Subject: [PATCH 06/54] go back to using InvokeAI attention --- ldm/models/diffusion/cross_attention.py | 25 +- ldm/models/diffusion/ksampler.py | 6 +- ldm/modules/attention.py | 513 +++++++----------------- 3 files changed, 172 insertions(+), 372 deletions(-) diff --git a/ldm/models/diffusion/cross_attention.py b/ldm/models/diffusion/cross_attention.py index d829162f35..bcc6c8cc94 100644 --- a/ldm/models/diffusion/cross_attention.py +++ b/ldm/models/diffusion/cross_attention.py @@ -75,11 +75,12 @@ class CrossAttentionControl: @classmethod def inject_attention_functions(cls, unet): # ORIGINAL SOURCE CODE: 
https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 + def new_attention(self, query, key, value): - # TODO: use baddbmm for better performance - attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale - attn_slice = attention_scores.softmax(dim=-1) - # compute attention output + + attention_scores = torch.functional.einsum('b i d, b j d -> b i j', query, key) + # calculate attention slice by taking the best scores for each latent pixel + attn_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype) if self.use_last_attn_slice: if self.last_attn_slice_mask is not None: @@ -100,13 +101,12 @@ class CrossAttentionControl: attn_slice = attn_slice * self.last_attn_slice_weights self.use_last_attn_weights = False - hidden_states = torch.matmul(attn_slice, value) - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states + return torch.functional.einsum('b i j, b j d -> b i d', attn_slice, value) def new_sliced_attention(self, query, key, value, sequence_length, dim): + raise NotImplementedError("not tested yet") + batch_size_attention = query.shape[0] hidden_states = torch.zeros( (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype @@ -146,6 +146,12 @@ class CrossAttentionControl: hidden_states = self.reshape_batch_dim_to_heads(hidden_states) return hidden_states + def select_attention_func(module, q, k, v, dim, offset, slice_size): + if dim == 0 or dim == 1: + return new_sliced_attention(module, q, k, v, sequence_length=slice_size, dim=dim) + else: + return new_attention(module, q, k, v) + for name, module in unet.named_modules(): module_name = type(module).__name__ if module_name == "CrossAttention": @@ -153,8 +159,7 @@ class CrossAttentionControl: module.use_last_attn_slice = False module.use_last_attn_weights = False module.save_last_attn_slice = False - module._sliced_attention = new_sliced_attention.__get__(module, type(module)) - module._attention = new_attention.__get__(module, type(module)) + module.set_custom_attention_calculator(select_attention_func) # original code below diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index e5d521f33f..f5af25fc04 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -48,21 +48,21 @@ class CFGDenoiser(nn.Module): def forward(self, x, sigma, uncond, cond, cond_scale): - print('generating unconditioned latents') + #rint('generating unconditioned latents') unconditioned_latents = self.inner_model(x, sigma, cond=uncond) # process x using the original prompt, saving the attention maps if required if self.edited_conditioning is not None: # this is automatically toggled off after the model forward() CrossAttentionControl.request_save_attention_maps(self.inner_model) - print('generating conditioned latents') + #print('generating conditioned latents') conditioned_latents = self.inner_model(x, sigma, cond=cond) if self.edited_conditioning is not None: # process x again, using the saved attention maps but the new conditioning # this is automatically toggled off after the model forward() CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model) - print('generating edited conditioned latents') + #print('generating edited conditioned latents') conditioned_latents = self.inner_model(x, sigma, cond=self.edited_conditioning) if self.warmup < self.warmup_max: diff --git 
a/ldm/modules/attention.py b/ldm/modules/attention.py index d00b95b1af..1ee0795fdd 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -1,367 +1,158 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +from inspect import isfunction import math -from typing import Optional +from typing import Callable import torch import torch.nn.functional as F -from torch import nn +from torch import nn, einsum +from einops import rearrange, repeat + +from ldm.modules.diffusionmodules.util import checkpoint + +import psutil + +def exists(val): + return val is not None -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted - to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. - Uses three q, k, v linear layers to compute attention. - - Parameters: - channels (:obj:`int`): The number of channels in the input and output. - num_head_channels (:obj:`int`, *optional*): - The number of channels in each head. If None, then `num_heads` = 1. - num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm. - rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by. - eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. 
- """ - - def __init__( - self, - channels: int, - num_head_channels: Optional[int] = None, - num_groups: int = 32, - rescale_output_factor: float = 1.0, - eps: float = 1e-5, - ): - super().__init__() - self.channels = channels - - self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 - self.num_head_size = num_head_channels - self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True) - - # define q,k,v as linear layers - self.query = nn.Linear(channels, channels) - self.key = nn.Linear(channels, channels) - self.value = nn.Linear(channels, channels) - - self.rescale_output_factor = rescale_output_factor - self.proj_attn = nn.Linear(channels, channels, 1) - - def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor: - new_projection_shape = projection.size()[:-1] + (self.num_heads, -1) - # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D) - new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3) - return new_projection - - def forward(self, hidden_states): - residual = hidden_states - batch, channel, height, width = hidden_states.shape - - # norm - hidden_states = self.group_norm(hidden_states) - - hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) - - # proj to q, k, v - query_proj = self.query(hidden_states) - key_proj = self.key(hidden_states) - value_proj = self.value(hidden_states) - - # transpose - query_states = self.transpose_for_scores(query_proj) - key_states = self.transpose_for_scores(key_proj) - value_states = self.transpose_for_scores(value_proj) - - # get scores - scale = 1 / math.sqrt(math.sqrt(self.channels / self.num_heads)) - attention_scores = torch.matmul(query_states * scale, key_states.transpose(-1, -2) * scale) # TODO: use baddmm - attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) - - # compute attention output - hidden_states = torch.matmul(attention_probs, value_states) - - hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous() - new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,) - hidden_states = hidden_states.view(new_hidden_states_shape) - - # compute next hidden_states - hidden_states = self.proj_attn(hidden_states) - hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) - - # res connect and rescale - hidden_states = (hidden_states + residual) / self.rescale_output_factor - return hidden_states +def uniq(arr): + return{el: True for el in arr}.keys() -class SpatialTransformer(nn.Module): - """ - Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply - standard transformer action. Finally, reshape to image. - - Parameters: - in_channels (:obj:`int`): The number of channels in the input and output. - n_heads (:obj:`int`): The number of heads to use for multi-head attention. - d_head (:obj:`int`): The number of channels in each head. - depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. - dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use. - context_dim (:obj:`int`, *optional*): The number of context dimensions to use. 
- """ - - def __init__( - self, - in_channels: int, - n_heads: int, - d_head: int, - depth: int = 1, - dropout: float = 0.0, - num_groups: int = 32, - context_dim: Optional[int] = None, - ): - super().__init__() - self.n_heads = n_heads - self.d_head = d_head - self.in_channels = in_channels - inner_dim = n_heads * d_head - self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) - - self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) - - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) - for d in range(depth) - ] - ) - - self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - - def _set_attention_slice(self, slice_size): - for block in self.transformer_blocks: - block._set_attention_slice(slice_size) - - def forward(self, hidden_states, context=None): - # note: if no context is given, cross-attention defaults to self-attention - batch, channel, height, weight = hidden_states.shape - residual = hidden_states - hidden_states = self.norm(hidden_states) - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) - for block in self.transformer_blocks: - hidden_states = block(hidden_states, context=context) - hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2) - hidden_states = self.proj_out(hidden_states) - return hidden_states + residual +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d -class BasicTransformerBlock(nn.Module): - r""" - A basic Transformer block. - - Parameters: - dim (:obj:`int`): The number of channels in the input and output. - n_heads (:obj:`int`): The number of heads to use for multi-head attention. - d_head (:obj:`int`): The number of channels in each head. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. - context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention. - gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network. - checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing. 
- """ - - def __init__( - self, - dim: int, - n_heads: int, - d_head: int, - dropout=0.0, - context_dim: Optional[int] = None, - gated_ff: bool = True, - checkpoint: bool = True, - ): - super().__init__() - self.attn1 = CrossAttention( - query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout - ) # is a self-attention - self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) - self.attn2 = CrossAttention( - query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout - ) # is self-attn if context is none - self.norm1 = nn.LayerNorm(dim) - self.norm2 = nn.LayerNorm(dim) - self.norm3 = nn.LayerNorm(dim) - self.checkpoint = checkpoint - - def _set_attention_slice(self, slice_size): - self.attn1._slice_size = slice_size - self.attn2._slice_size = slice_size - - def forward(self, hidden_states, context=None): - hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states - hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states - hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states - hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states - return hidden_states +def max_neg_value(t): + return -torch.finfo(t.dtype).max -class CrossAttention(nn.Module): - r""" - A cross attention layer. - - Parameters: - query_dim (:obj:`int`): The number of channels in the query. - context_dim (:obj:`int`, *optional*): - The number of channels in the context. If not given, defaults to `query_dim`. - heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. - dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. 
- """ - - def __init__( - self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0 - ): - super().__init__() - inner_dim = dim_head * heads - context_dim = context_dim if context_dim is not None else query_dim - - self.scale = dim_head**-0.5 - self.heads = heads - # for slice_size > 0 the attention score computation - # is split across the batch axis to save memory - # You can set slice_size with `set_attention_slice` - self._slice_size = None - - self.to_q = nn.Linear(query_dim, inner_dim, bias=False) - self.to_k = nn.Linear(context_dim, inner_dim, bias=False) - self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - - self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) - - def reshape_heads_to_batch_dim(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) - return tensor - - def reshape_batch_dim_to_heads(self, tensor): - batch_size, seq_len, dim = tensor.shape - head_size = self.heads - tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) - tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor - - def forward(self, hidden_states, context=None, mask=None): - batch_size, sequence_length, _ = hidden_states.shape - - query = self.to_q(hidden_states) - context = context if context is not None else hidden_states - key = self.to_k(context) - value = self.to_v(context) - - dim = query.shape[-1] - - query = self.reshape_heads_to_batch_dim(query) - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - # TODO(PVP) - mask is currently never used. 
Remember to re-implement when used - - # attention, what we cannot get enough of - - if self._slice_size is None or query.shape[0] // self._slice_size == 1: - hidden_states = self._attention(query, key, value) - else: - hidden_states = self._sliced_attention(query, key, value, sequence_length, dim) - - return self.to_out(hidden_states) - - def _attention(self, query, key, value): - # TODO: use baddbmm for better performance - attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale - attention_probs = attention_scores.softmax(dim=-1) - # compute attention output - hidden_states = torch.matmul(attention_probs, value) - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states - - def _sliced_attention(self, query, key, value, sequence_length, dim): - batch_size_attention = query.shape[0] - hidden_states = torch.zeros( - (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype - ) - slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] - for i in range(hidden_states.shape[0] // slice_size): - start_idx = i * slice_size - end_idx = (i + 1) * slice_size - attn_slice = ( - torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale - ) # TODO: use baddbmm for better performance - attn_slice = attn_slice.softmax(dim=-1) - attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states - - -class FeedForward(nn.Module): - r""" - A feed-forward layer. - - Parameters: - dim (:obj:`int`): The number of channels in the input. - dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. - mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. - glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation. - dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use. - """ - - def __init__( - self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0 - ): - super().__init__() - inner_dim = int(dim * mult) - dim_out = dim_out if dim_out is not None else dim - project_in = GEGLU(dim, inner_dim) - - self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) - - def forward(self, hidden_states): - return self.net(hidden_states) +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor # feedforward class GEGLU(nn.Module): - r""" - A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. - - Parameters: - dim_in (:obj:`int`): The number of channels in the input. - dim_out (:obj:`int`): The number of channels in the output. 
- """ - - def __init__(self, dim_in: int, dim_out: int): + def __init__(self, dim_in, dim_out): super().__init__() self.proj = nn.Linear(dim_in, dim_out * 2) - def forward(self, hidden_states): - hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) - return hidden_states * F.gelu(gate) -''' + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + + class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): super().__init__() @@ -382,39 +173,51 @@ class CrossAttention(nn.Module): self.mem_total_gb = psutil.virtual_memory().total // (1 << 30) - self.cross_attention_callback = None + self.custom_attention_calculator = None + + def set_custom_attention_calculator(self, callback:Callable[[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]): + ''' + Set custom attention calculator to be called when attention is calculated + 
:param callback: Callback, with args q, k, v, dim, offset, slice_size, which returns attention info. + q, k, v are as regular attention calculator. + dim is -1 if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing. + If dim is >= 0, offset and slice_size specify the slice start and length. + Pass None to use the default attention calculation. + :return: + ''' + self.custom_attention_calculator = callback def einsum_op_slice_dim0(self, q, k, v, slice_size, callback): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[0], slice_size): end = i + slice_size - r[i:end] = callback(q[i:end], k[i:end], v[i:end], offset=i) + r[i:end] = callback(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size) return r def einsum_op_slice_dim1(self, q, k, v, slice_size, callback): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[1], slice_size): end = i + slice_size - r[:, i:end] = callback(q[:, i:end], k, v, offset=i) + r[:, i:end] = callback(self, q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size) return r def einsum_op_mps_v1(self, q, k, v, callback): if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 - return callback(q, k, v) + return callback(self, q, k, v, -1, 0, 0) else: slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) return self.einsum_op_slice_dim1(q, k, v, slice_size, callback) def einsum_op_mps_v2(self, q, k, v, callback): if self.mem_total_gb > 8 and q.shape[1] <= 4096: - return callback(q, k, v, offset=0) + return callback(self, q, k, v, -1, 0, 0) else: return self.einsum_op_slice_dim0(q, k, v, 1, callback) def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb, callback): size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) if size_mb <= max_tensor_mb: - return callback(q, k, v, offset=0) + return callback(self, q, k, v, offset=0) div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() if div <= q.shape[0]: print("warning: untested call to einsum_op_slice_dim0") @@ -433,12 +236,6 @@ class CrossAttention(nn.Module): return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20), callback) def get_attention_mem_efficient(self, q, k, v, callback): - """ - Calculate attention by slicing q, k, and v for memory efficiency then calling - callback(q, k, v, offset=offset) - multiple times if necessary. 
The offset argument is something - """ - if q.device.type == 'cuda': return self.einsum_op_cuda(q, k, v, callback) @@ -479,7 +276,6 @@ class CrossAttention(nn.Module): return self.to_out(hidden_states) - class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): super().__init__() @@ -547,4 +343,3 @@ class SpatialTransformer(nn.Module): x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) x = self.proj_out(x) return x + x_in -''' \ No newline at end of file From 056cb0d8a8588c1f0af620df2aa50ad3efc93ec3 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Tue, 18 Oct 2022 11:48:33 +0200 Subject: [PATCH 07/54] sliced cross-attention wrangler works --- ldm/models/diffusion/cross_attention.py | 140 +++++++++++++----------- ldm/models/diffusion/ksampler.py | 3 + ldm/modules/attention.py | 83 +++++++------- 3 files changed, 123 insertions(+), 103 deletions(-) diff --git a/ldm/models/diffusion/cross_attention.py b/ldm/models/diffusion/cross_attention.py index bcc6c8cc94..d2c9d0fb02 100644 --- a/ldm/models/diffusion/cross_attention.py +++ b/ldm/models/diffusion/cross_attention.py @@ -56,6 +56,13 @@ class CrossAttentionControl: return [module for name, module in model.named_modules() if type(module).__name__ == "CrossAttention" and which_attn in name] + @classmethod + def clear_requests(cls, model): + self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF) + tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS) + for m in self_attention_modules+tokens_attention_modules: + m.save_last_attn_slice = False + m.use_last_attn_slice = False @classmethod def request_save_attention_maps(cls, model): @@ -76,81 +83,84 @@ class CrossAttentionControl: def inject_attention_functions(cls, unet): # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 - def new_attention(self, query, key, value): + def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size): + + attn_slice = suggested_attention_slice + if dim is not None: + start = offset + end = start+slice_size + #print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}") + #else: + # print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}") - attention_scores = torch.functional.einsum('b i d, b j d -> b i j', query, key) - # calculate attention slice by taking the best scores for each latent pixel - attn_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype) if self.use_last_attn_slice: + this_attn_slice = attn_slice if self.last_attn_slice_mask is not None: - base_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices) + # indices and mask operate on dim=2, no need to slice + base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices) base_attn_slice_mask = self.last_attn_slice_mask - this_attn_slice_mask = 1 - self.last_attn_slice_mask - attn_slice = attn_slice * this_attn_slice_mask + base_attn_slice * base_attn_slice_mask - else: - attn_slice = self.last_attn_slice + if dim is None: + base_attn_slice = base_attn_slice_full + #print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", 
base_attn_slice_full.shape) + elif dim == 0: + base_attn_slice = base_attn_slice_full[start:end] + #print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) + elif dim == 1: + base_attn_slice = base_attn_slice_full[:, start:end] + #print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) - self.use_last_attn_slice = False + attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \ + base_attn_slice * base_attn_slice_mask + else: + if dim is None: + attn_slice = self.last_attn_slice + #print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) + elif dim == 0: + attn_slice = self.last_attn_slice[start:end] + #print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) + elif dim == 1: + attn_slice = self.last_attn_slice[:, start:end] + #print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) if self.save_last_attn_slice: - self.last_attn_slice = attn_slice - self.save_last_attn_slice = False + if dim is None: + self.last_attn_slice = attn_slice + elif dim == 0: + # dynamically grow last_attn_slice if needed + if self.last_attn_slice is None: + self.last_attn_slice = attn_slice + #print("no last_attn_slice: shape now", self.last_attn_slice.shape) + elif self.last_attn_slice.shape[0] == start: + self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0) + assert(self.last_attn_slice.shape[0] == end) + #print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape) + else: + # no need to grow + self.last_attn_slice[start:end] = attn_slice + #print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape) + + elif dim == 1: + # dynamically grow last_attn_slice if needed + if self.last_attn_slice is None: + self.last_attn_slice = attn_slice + elif self.last_attn_slice.shape[1] == start: + self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1) + assert(self.last_attn_slice.shape[1] == end) + else: + # no need to grow + self.last_attn_slice[:, start:end] = attn_slice if self.use_last_attn_weights and self.last_attn_slice_weights is not None: - attn_slice = attn_slice * self.last_attn_slice_weights - self.use_last_attn_weights = False + if dim is None: + weights = self.last_attn_slice_weights + elif dim == 0: + weights = self.last_attn_slice_weights[start:end] + elif dim == 1: + weights = self.last_attn_slice_weights[:, start:end] + attn_slice = attn_slice * weights - return torch.functional.einsum('b i j, b j d -> b i d', attn_slice, value) - - def new_sliced_attention(self, query, key, value, sequence_length, dim): - - raise NotImplementedError("not tested yet") - - batch_size_attention = query.shape[0] - hidden_states = torch.zeros( - (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype - ) - slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] - for i in range(hidden_states.shape[0] // slice_size): - start_idx = i * slice_size - end_idx = (i + 1) * slice_size - attn_slice = ( - torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale - ) # TODO: use baddbmm for better performance - attn_slice = attn_slice.softmax(dim=-1) - - if self.use_last_attn_slice: - if 
self.last_attn_slice_mask is not None: - new_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices) - attn_slice = attn_slice * ( - 1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask - else: - attn_slice = self.last_attn_slice - - self.use_last_attn_slice = False - - if self.save_last_attn_slice: - self.last_attn_slice = attn_slice - self.save_last_attn_slice = False - - if self.use_last_attn_weights and self.last_attn_slice_weights is not None: - attn_slice = attn_slice * self.last_attn_slice_weights - self.use_last_attn_weights = False - - attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx]) - - hidden_states[start_idx:end_idx] = attn_slice - - # reshape hidden_states - hidden_states = self.reshape_batch_dim_to_heads(hidden_states) - return hidden_states - - def select_attention_func(module, q, k, v, dim, offset, slice_size): - if dim == 0 or dim == 1: - return new_sliced_attention(module, q, k, v, sequence_length=slice_size, dim=dim) - else: - return new_attention(module, q, k, v) + return attn_slice for name, module in unet.named_modules(): module_name = type(module).__name__ @@ -159,7 +169,7 @@ class CrossAttentionControl: module.use_last_attn_slice = False module.use_last_attn_weights = False module.save_last_attn_slice = False - module.set_custom_attention_calculator(select_attention_func) + module.set_attention_slice_wrangler(attention_slice_wrangler) # original code below diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index f5af25fc04..3b5a59a981 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -48,6 +48,8 @@ class CFGDenoiser(nn.Module): def forward(self, x, sigma, uncond, cond, cond_scale): + CrossAttentionControl.clear_requests(self.inner_model) + #rint('generating unconditioned latents') unconditioned_latents = self.inner_model(x, sigma, cond=uncond) @@ -61,6 +63,7 @@ class CFGDenoiser(nn.Module): if self.edited_conditioning is not None: # process x again, using the saved attention maps but the new conditioning # this is automatically toggled off after the model forward() + CrossAttentionControl.clear_requests(self.inner_model) CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model) #print('generating edited conditioned latents') conditioned_latents = self.inner_model(x, sigma, cond=self.edited_conditioning) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 1ee0795fdd..8d160f004b 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -173,59 +173,75 @@ class CrossAttention(nn.Module): self.mem_total_gb = psutil.virtual_memory().total // (1 << 30) - self.custom_attention_calculator = None + self.attention_slice_wrangler = None - def set_custom_attention_calculator(self, callback:Callable[[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]): + def set_attention_slice_wrangler(self, wrangler:Callable[[nn.Module, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]): ''' Set custom attention calculator to be called when attention is calculated - :param callback: Callback, with args q, k, v, dim, offset, slice_size, which returns attention info. - q, k, v are as regular attention calculator. + :param wrangler: Callback, with args (self, attention_scores, suggested_attention_slice, dim, offset, slice_size), + which returns either the suggested_attention_slice or an adjusted equivalent. 
+ self is the current CrossAttention module for which the callback is being invoked. + attention_scores are the scores for attention + suggested_attention_slice is a softmax(dim=-1) over attention_scores dim is -1 if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing. If dim is >= 0, offset and slice_size specify the slice start and length. + Pass None to use the default attention calculation. :return: ''' - self.custom_attention_calculator = callback + self.attention_slice_wrangler = wrangler - def einsum_op_slice_dim0(self, q, k, v, slice_size, callback): + def einsum_lowest_level(self, q, k, v, dim, offset, slice_size): + # calculate attention scores + attention_scores = einsum('b i d, b j d -> b i j', q, k) + # calculate attenion slice by taking the best scores for each latent pixel + default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype) + if self.attention_slice_wrangler is not None: + attention_slice = self.attention_slice_wrangler(self, attention_scores, default_attention_slice, dim, offset, slice_size) + else: + attention_slice = default_attention_slice + + return einsum('b i j, b j d -> b i d', attention_slice, v) + + def einsum_op_slice_dim0(self, q, k, v, slice_size): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[0], slice_size): end = i + slice_size - r[i:end] = callback(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size) + r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size) return r - def einsum_op_slice_dim1(self, q, k, v, slice_size, callback): + def einsum_op_slice_dim1(self, q, k, v, slice_size): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[1], slice_size): end = i + slice_size - r[:, i:end] = callback(self, q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size) + r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size) return r - def einsum_op_mps_v1(self, q, k, v, callback): + def einsum_op_mps_v1(self, q, k, v): if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 - return callback(self, q, k, v, -1, 0, 0) + return self.einsum_lowest_level(q, k, v, None, None, None) else: slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) - return self.einsum_op_slice_dim1(q, k, v, slice_size, callback) + return self.einsum_op_slice_dim1(q, k, v, slice_size) - def einsum_op_mps_v2(self, q, k, v, callback): + def einsum_op_mps_v2(self, q, k, v): if self.mem_total_gb > 8 and q.shape[1] <= 4096: - return callback(self, q, k, v, -1, 0, 0) + return self.einsum_lowest_level(q, k, v, None, None, None) else: - return self.einsum_op_slice_dim0(q, k, v, 1, callback) + return self.einsum_op_slice_dim0(q, k, v, 1) - def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb, callback): + def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb): size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) if size_mb <= max_tensor_mb: - return callback(self, q, k, v, offset=0) + return self.einsum_lowest_level(q, k, v, None, None, None) div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() if div <= q.shape[0]: print("warning: untested call to einsum_op_slice_dim0") - return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div, callback) + return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div) print("warning: untested call to einsum_op_slice_dim1") - return 
self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1), callback) + return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1)) - def einsum_op_cuda(self, q, k, v, callback): + def einsum_op_cuda(self, q, k, v): stats = torch.cuda.memory_stats(q.device) mem_active = stats['active_bytes.all.current'] mem_reserved = stats['reserved_bytes.all.current'] @@ -233,20 +249,20 @@ class CrossAttention(nn.Module): mem_free_torch = mem_reserved - mem_active mem_free_total = mem_free_cuda + mem_free_torch # Divide factor of safety as there's copying and fragmentation - return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20), callback) + return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) - def get_attention_mem_efficient(self, q, k, v, callback): + def get_attention_mem_efficient(self, q, k, v): if q.device.type == 'cuda': - return self.einsum_op_cuda(q, k, v, callback) + return self.einsum_op_cuda(q, k, v) if q.device.type == 'mps': if self.mem_total_gb >= 32: - return self.einsum_op_mps_v1(q, k, v, callback) - return self.einsum_op_mps_v2(q, k, v, callback) + return self.einsum_op_mps_v1(q, k, v) + return self.einsum_op_mps_v2(q, k, v) # Smaller slices are faster due to L2/L3/SLC caches. # Tested on i7 with 8MB L3 cache. - return self.einsum_op_tensor_mem(q, k, v, 32, callback) + return self.einsum_op_tensor_mem(q, k, v, 32) def forward(self, x, context=None, mask=None): h = self.heads @@ -259,23 +275,14 @@ class CrossAttention(nn.Module): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - def default_attention_calculator(q, k, v, **kwargs): - # calculate attention scores - attention_scores = einsum('b i d, b j d -> b i j', q, k) - # calculate attenion slice by taking the best scores for each latent pixel - attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype) - return einsum('b i j, b j d -> b i d', attention_slice, v) - - attention_calculator = \ - self.custom_attention_calculator if self.custom_attention_calculator is not None \ - else default_attention_calculator - - r = self.get_attention_mem_efficient(q, k, v, attention_calculator) + r = self.get_attention_mem_efficient(q, k, v) hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h) return self.to_out(hidden_states) + + class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): super().__init__() From 711ffd238f3ed1a241f0df20b21ebab6ca8a7308 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Tue, 18 Oct 2022 13:52:40 +0200 Subject: [PATCH 08/54] cleanup --- ldm/models/diffusion/cross_attention.py | 69 +++++++++++++++---------- ldm/models/diffusion/ksampler.py | 2 +- 2 files changed, 43 insertions(+), 28 deletions(-) diff --git a/ldm/models/diffusion/cross_attention.py b/ldm/models/diffusion/cross_attention.py index d2c9d0fb02..d5c3eaadf0 100644 --- a/ldm/models/diffusion/cross_attention.py +++ b/ldm/models/diffusion/cross_attention.py @@ -6,14 +6,16 @@ import torch # https://github.com/bloc97/CrossAttentionControl class CrossAttentionControl: - class AttentionType(Enum): - SELF = 1 - TOKENS = 2 + + + @classmethod + def clear_attention_editing(cls, model): + cls.remove_attention_function(model) @classmethod def setup_attention_editing(cls, model, - substitute_conditioning: torch.Tensor = None, - edit_opcodes: list = None): + substitute_conditioning: torch.Tensor, + edit_opcodes: list): """ :param model: The unet model to inject into. 
:param substitute_conditioning: The "edited" conditioning vector, [Bx77x768] @@ -23,31 +25,34 @@ class CrossAttentionControl: """ # adapted from init_attention_edit - if substitute_conditioning is not None: + device = substitute_conditioning.device - device = substitute_conditioning.device + max_length = model.inner_model.cond_stage_model.max_length + # mask=1 means use base prompt attention, mask=0 means use edited prompt attention + mask = torch.zeros(max_length) + indices_target = torch.arange(max_length, dtype=torch.long) + indices = torch.zeros(max_length, dtype=torch.long) + for name, a0, a1, b0, b1 in edit_opcodes: + if b0 < max_length: + if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): + # these tokens have not been edited + indices[b0:b1] = indices_target[a0:a1] + mask[b0:b1] = 1 - max_length = model.inner_model.cond_stage_model.max_length - # mask=1 means use base prompt attention, mask=0 means use edited prompt attention - mask = torch.zeros(max_length) - indices_target = torch.arange(max_length, dtype=torch.long) - indices = torch.zeros(max_length, dtype=torch.long) - for name, a0, a1, b0, b1 in edit_opcodes: - if b0 < max_length: - if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): - # these tokens have not been edited - indices[b0:b1] = indices_target[a0:a1] - mask[b0:b1] = 1 + for m in cls.get_attention_modules(model, cls.AttentionType.SELF): + m.last_attn_slice_mask = None + m.last_attn_slice_indices = None - for m in cls.get_attention_modules(model, cls.AttentionType.SELF): - m.last_attn_slice_mask = None - m.last_attn_slice_indices = None + for m in cls.get_attention_modules(model, cls.AttentionType.TOKENS): + m.last_attn_slice_mask = mask.to(device) + m.last_attn_slice_indices = indices.to(device) - for m in cls.get_attention_modules(model, cls.AttentionType.TOKENS): - m.last_attn_slice_mask = mask.to(device) - m.last_attn_slice_indices = indices.to(device) + cls.inject_attention_function(model) - cls.inject_attention_functions(model) + + class AttentionType(Enum): + SELF = 1 + TOKENS = 2 @classmethod @@ -79,8 +84,9 @@ class CrossAttentionControl: m.use_last_attn_slice = True + @classmethod - def inject_attention_functions(cls, unet): + def inject_attention_function(cls, unet): # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size): @@ -166,11 +172,20 @@ class CrossAttentionControl: module_name = type(module).__name__ if module_name == "CrossAttention": module.last_attn_slice = None - module.use_last_attn_slice = False + module.last_attn_slice_indices = None + module.last_attn_slice_mask = None module.use_last_attn_weights = False + module.use_last_attn_slice = False module.save_last_attn_slice = False module.set_attention_slice_wrangler(attention_slice_wrangler) + @classmethod + def remove_attention_function(cls, unet): + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention": + module.set_attention_slice_wrangler(None) + # original code below diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 3b5a59a981..c8b4823111 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -44,7 +44,7 @@ class CFGDenoiser(nn.Module): CrossAttentionControl.setup_attention_editing(self.inner_model, edited_conditioning, 
edit_opcodes) else: # pass through the attention func but don't act on it - CrossAttentionControl.setup_attention_editing(self.inner_model) + CrossAttentionControl.clear_attention_editing(self.inner_model) def forward(self, x, sigma, uncond, cond, cond_scale): From 09f62032ec1cfad3e47074856a0da5718251570f Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Tue, 18 Oct 2022 19:49:25 +0200 Subject: [PATCH 09/54] cleanup and clarify comments --- ldm/models/diffusion/cross_attention.py | 4 +++ ldm/models/diffusion/ksampler.py | 40 +++++++++++++++---------- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/ldm/models/diffusion/cross_attention.py b/ldm/models/diffusion/cross_attention.py index d5c3eaadf0..71d5995b4a 100644 --- a/ldm/models/diffusion/cross_attention.py +++ b/ldm/models/diffusion/cross_attention.py @@ -74,6 +74,8 @@ class CrossAttentionControl: self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF) tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS) for m in self_attention_modules+tokens_attention_modules: + # clear out the saved slice in case the outermost dim changes + m.last_attn_slice = None m.save_last_attn_slice = True @classmethod @@ -91,6 +93,8 @@ class CrossAttentionControl: def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size): + #print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim) + attn_slice = suggested_attention_slice if dim is not None: start = offset diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index c8b4823111..7459e2e7cc 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -50,23 +50,32 @@ class CFGDenoiser(nn.Module): CrossAttentionControl.clear_requests(self.inner_model) - #rint('generating unconditioned latents') - unconditioned_latents = self.inner_model(x, sigma, cond=uncond) + if self.edited_conditioning is None: + # faster batch path + x_twice = torch.cat([x]*2) + sigma_twice = torch.cat([sigma]*2) + both_conditionings = torch.cat([uncond, cond]) + unconditioned_next_x, conditioned_next_x = self.inner_model(x_twice, sigma_twice, cond=both_conditionings).chunk(2) + else: + # slower non-batched path (20% slower on mac MPS) + # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of + # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x. + # This messes app their application later, due to mismatched shape of dim 0 (16 vs. 8) + # (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16, + # representing batched uncond + cond, but then when it comes to applying the saved attention, the + # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.) + # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well. 
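The "16 vs. 8" mismatch described above comes from CrossAttention folding batch and heads together on dim 0 of the attention tensor. A minimal sketch of the arithmetic, assuming 8 attention heads (as in the SD 1.x UNet) and purely illustrative tensor sizes:

    import torch
    from einops import rearrange

    h = 8                  # attention heads (assumed for SD 1.x)
    dim_head = 40          # illustrative only
    latent_pixels = 4096   # 64x64 latents for a 512x512 image

    def attn_dim0(batch_size):
        # q is rearranged 'b n (h d) -> (b h) n d', so attention scores get shape[0] = batch * heads
        q = torch.randn(batch_size, latent_pixels, h * dim_head)
        return rearrange(q, 'b n (h d) -> (b h) n d', h=h).shape[0]

    print(attn_dim0(2))   # batched uncond+cond forward -> 16 rows in the saved slice
    print(attn_dim0(1))   # single edited-conditioning forward -> 8 rows; saved 16-row slices would not line up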
+ unconditioned_next_x = self.inner_model(x, sigma, cond=uncond) - # process x using the original prompt, saving the attention maps if required - if self.edited_conditioning is not None: - # this is automatically toggled off after the model forward() + # process x using the original prompt, saving the attention maps CrossAttentionControl.request_save_attention_maps(self.inner_model) - #print('generating conditioned latents') - conditioned_latents = self.inner_model(x, sigma, cond=cond) - - if self.edited_conditioning is not None: - # process x again, using the saved attention maps but the new conditioning - # this is automatically toggled off after the model forward() + _ = self.inner_model(x, sigma, cond=cond) CrossAttentionControl.clear_requests(self.inner_model) + + # process x again, using the saved attention maps to control where self.edited_conditioning will be applied CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model) - #print('generating edited conditioned latents') - conditioned_latents = self.inner_model(x, sigma, cond=self.edited_conditioning) + conditioned_next_x = self.inner_model(x, sigma, cond=self.edited_conditioning) + CrossAttentionControl.clear_requests(self.inner_model) if self.warmup < self.warmup_max: thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) @@ -75,8 +84,9 @@ class CFGDenoiser(nn.Module): thresh = self.threshold if thresh > self.threshold: thresh = self.threshold - delta = (conditioned_latents - unconditioned_latents) - return cfg_apply_threshold(unconditioned_latents + delta * cond_scale, thresh) + # to scale how much effect conditioning has, calculate the changes it does and then scale that + scaled_delta = (conditioned_next_x - unconditioned_next_x) * cond_scale + return cfg_apply_threshold(unconditioned_next_x + scaled_delta, thresh) class KSampler(Sampler): From 54e6a68acbf314ddbcbda450bb0c5d39b39af4a3 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Tue, 18 Oct 2022 22:09:06 +0200 Subject: [PATCH 10/54] wip bringing cross-attention to PLMS and DDIM --- ldm/invoke/generator/txt2img.py | 4 +- ldm/models/diffusion/cross_attention.py | 52 ++++++++++++++++++++- ldm/models/diffusion/ddim.py | 21 +++++++-- ldm/models/diffusion/ksampler.py | 62 +++++++------------------ ldm/models/diffusion/plms.py | 28 ++++++++--- ldm/models/diffusion/sampler.py | 8 ++-- 6 files changed, 112 insertions(+), 63 deletions(-) diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index 9f066745f7..669f3d81ff 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -19,7 +19,7 @@ class Txt2Img(Generator): kwargs are 'width' and 'height' """ self.perlin = perlin - uc, c, ec, edit_index_map = conditioning + uc, c, ec, edit_opcodes = conditioning @torch.no_grad() def make_image(x_T): @@ -44,7 +44,7 @@ class Txt2Img(Generator): unconditional_guidance_scale = cfg_scale, unconditional_conditioning = uc, edited_conditioning = ec, - edit_token_index_map = edit_index_map, + conditioning_edit_opcodes = edit_opcodes, eta = ddim_eta, img_callback = step_callback, threshold = threshold, diff --git a/ldm/models/diffusion/cross_attention.py b/ldm/models/diffusion/cross_attention.py index 71d5995b4a..c0760fff47 100644 --- a/ldm/models/diffusion/cross_attention.py +++ b/ldm/models/diffusion/cross_attention.py @@ -2,6 +2,55 @@ from enum import Enum import torch + +class CrossAttentionControllableDiffusionMixin: + + def setup_cross_attention_control_if_appropriate(self, model, edited_conditioning, 
edit_opcodes): + self.edited_conditioning = edited_conditioning + + if edited_conditioning is not None: + # a cat sitting on a car + CrossAttentionControl.setup_attention_editing(model, edited_conditioning, edit_opcodes) + else: + # pass through the attention func but don't act on it + CrossAttentionControl.clear_attention_editing(model) + + def cleanup_cross_attention_control(self, model): + CrossAttentionControl.clear_attention_editing(model) + + def do_cross_attention_controllable_diffusion_step(self, x, sigma, unconditioning, conditioning, model, model_forward_callback): + + CrossAttentionControl.clear_requests(model) + + if self.edited_conditioning is None: + # faster batched path + x_twice = torch.cat([x]*2) + sigma_twice = torch.cat([sigma]*2) + both_conditionings = torch.cat([unconditioning, conditioning]) + unconditioned_next_x, conditioned_next_x = model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2) + else: + # slower non-batched path (20% slower on mac MPS) + # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of + # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x. + # This messes app their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8) + # (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16, + # representing batched uncond + cond, but then when it comes to applying the saved attention, the + # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.) + # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well. + unconditioned_next_x = model_forward_callback(x, sigma, unconditioning) + + # process x using the original prompt, saving the attention maps + CrossAttentionControl.request_save_attention_maps(model) + _ = model_forward_callback(x, sigma, cond=conditioning) + CrossAttentionControl.clear_requests(model) + + # process x again, using the saved attention maps to control where self.edited_conditioning will be applied + CrossAttentionControl.request_apply_saved_attention_maps(model) + conditioned_next_x = model_forward_callback(x, sigma, self.edited_conditioning) + CrossAttentionControl.clear_requests(model) + + return unconditioned_next_x, conditioned_next_x + # adapted from bloc97's CrossAttentionControl colab # https://github.com/bloc97/CrossAttentionControl @@ -27,7 +76,8 @@ class CrossAttentionControl: # adapted from init_attention_edit device = substitute_conditioning.device - max_length = model.inner_model.cond_stage_model.max_length + # urgh. should this be hardcoded? 
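For SD 1.x models the value 77 is the CLIP text encoder's fixed context length (the tokenizer's model_max_length), which is why hardcoding it works here. A more defensive sketch that reads the limit from the tokenizer instead; the attribute path is an assumption and varies between model wrappers:

    # illustrative only: fetch the token limit rather than hardcoding 77
    def get_max_token_length(model, fallback=77):
        cond_stage = getattr(model, 'cond_stage_model', None)
        tokenizer = getattr(cond_stage, 'tokenizer', None)
        return tokenizer.model_max_length if tokenizer is not None else fallback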
+ max_length = 77 # mask=1 means use base prompt attention, mask=0 means use edited prompt attention mask = torch.zeros(max_length) indices_target = torch.arange(max_length, dtype=torch.long) diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index f5dada8627..4980b03c42 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -5,13 +5,23 @@ import numpy as np from tqdm import tqdm from functools import partial from ldm.invoke.devices import choose_torch_device +from ldm.models.diffusion.cross_attention import CrossAttentionControllableDiffusionMixin from ldm.models.diffusion.sampler import Sampler from ldm.modules.diffusionmodules.util import noise_like -class DDIMSampler(Sampler): +class DDIMSampler(Sampler, CrossAttentionControllableDiffusionMixin): def __init__(self, model, schedule='linear', device=None, **kwargs): super().__init__(model,schedule,model.num_timesteps,device) + def prepare_to_sample(self, t_enc, **kwargs): + super().prepare_to_sample(t_enc, **kwargs) + + edited_conditioning = kwargs.get('edited_conditioning', None) + edit_opcodes = kwargs.get('conditioning_edit_opcodes', None) + + self.setup_cross_attention_control_if_appropriate(self.model, edited_conditioning, edit_opcodes) + + # This is the central routine @torch.no_grad() def p_sample( @@ -37,12 +47,13 @@ class DDIMSampler(Sampler): unconditional_conditioning is None or unconditional_guidance_scale == 1.0 ): + # damian0815 does not think this code path is ever used e_t = self.model.apply_model(x, t, c) else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + + e_t_uncond, e_t = self.do_cross_attention_controllable_diffusion_step(x, t, unconditional_conditioning, c, self.model, + model_forward_callback=lambda x, sigma, cond: self.model.apply_model(x, sigma, cond)) + e_t = e_t_uncond + unconditional_guidance_scale * ( e_t - e_t_uncond ) diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 7459e2e7cc..78d5978efe 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -13,7 +13,8 @@ from ldm.modules.diffusionmodules.util import ( noise_like, extract_into_tensor, ) -from ldm.models.diffusion.cross_attention import CrossAttentionControl +from ldm.models.diffusion.cross_attention import CrossAttentionControl, CrossAttentionControllableDiffusionMixin + def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): if threshold <= 0.0: @@ -29,53 +30,26 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): return torch.clamp(result, min=minval, max=maxval) -class CFGDenoiser(nn.Module): - def __init__(self, model, threshold = 0, warmup = 0, edited_conditioning = None, edit_opcodes = None): +class CFGDenoiser(nn.Module, CrossAttentionControllableDiffusionMixin): + def __init__(self, model, threshold = 0, warmup = 0): super().__init__() self.inner_model = model self.threshold = threshold self.warmup_max = warmup self.warmup = max(warmup / 10, 1) - self.edited_conditioning = edited_conditioning + def prepare_to_sample(self, t_enc, **kwargs): + + edited_conditioning = kwargs.get('edited_conditioning', None) + conditioning_edit_opcodes = kwargs.get('conditioning_edit_opcodes', None) + + self.setup_cross_attention_control_if_appropriate(self.model, edited_conditioning, conditioning_edit_opcodes) - if edited_conditioning is not None: - # a cat sitting on a car - 
CrossAttentionControl.setup_attention_editing(self.inner_model, edited_conditioning, edit_opcodes) - else: - # pass through the attention func but don't act on it - CrossAttentionControl.clear_attention_editing(self.inner_model) def forward(self, x, sigma, uncond, cond, cond_scale): - CrossAttentionControl.clear_requests(self.inner_model) - - if self.edited_conditioning is None: - # faster batch path - x_twice = torch.cat([x]*2) - sigma_twice = torch.cat([sigma]*2) - both_conditionings = torch.cat([uncond, cond]) - unconditioned_next_x, conditioned_next_x = self.inner_model(x_twice, sigma_twice, cond=both_conditionings).chunk(2) - else: - # slower non-batched path (20% slower on mac MPS) - # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of - # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x. - # This messes app their application later, due to mismatched shape of dim 0 (16 vs. 8) - # (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16, - # representing batched uncond + cond, but then when it comes to applying the saved attention, the - # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.) - # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well. - unconditioned_next_x = self.inner_model(x, sigma, cond=uncond) - - # process x using the original prompt, saving the attention maps - CrossAttentionControl.request_save_attention_maps(self.inner_model) - _ = self.inner_model(x, sigma, cond=cond) - CrossAttentionControl.clear_requests(self.inner_model) - - # process x again, using the saved attention maps to control where self.edited_conditioning will be applied - CrossAttentionControl.request_apply_saved_attention_maps(self.inner_model) - conditioned_next_x = self.inner_model(x, sigma, cond=self.edited_conditioning) - CrossAttentionControl.clear_requests(self.inner_model) + unconditioned_next_x, conditioned_next_x = self.do_cross_attention_controllable_diffusion_step(x, sigma, uncond, cond, self.inner_model, + model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond)) if self.warmup < self.warmup_max: thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) @@ -204,7 +178,7 @@ class KSampler(Sampler): unconditional_guidance_scale=1.0, unconditional_conditioning=None, edited_conditioning=None, - edit_token_index_map=None, + conditioning_edit_opcodes=None, threshold = 0, perlin = 0, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
@@ -236,21 +210,22 @@ class KSampler(Sampler): else: x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] - model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10), - edited_conditioning=edited_conditioning, edit_opcodes=edit_token_index_map) + model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10)) + model_wrap_cfg.prepare_to_sample(S, edited_conditioning=edited_conditioning, conditioning_edit_opcodes=conditioning_edit_opcodes) extra_args = { 'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale, } print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)') - return ( + sampling_result = ( K.sampling.__dict__[f'sample_{self.schedule}']( model_wrap_cfg, x, sigmas, extra_args=extra_args, callback=route_callback ), None, ) + return sampling_result # this code will support inpainting if and when ksampler API modified or # a workaround is found. @@ -312,7 +287,7 @@ class KSampler(Sampler): else: return x - def prepare_to_sample(self,t_enc): + def prepare_to_sample(self,t_enc,**kwargs): self.t_enc = t_enc self.model_wrap = None self.ds = None @@ -323,4 +298,3 @@ class KSampler(Sampler): Overrides parent method to return the q_sample of the inner model. ''' return self.model.inner_model.q_sample(x0,ts) - diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 9e722eb932..eb778813a0 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -5,14 +5,24 @@ import numpy as np from tqdm import tqdm from functools import partial from ldm.invoke.devices import choose_torch_device +from ldm.models.diffusion.cross_attention import CrossAttentionControllableDiffusionMixin from ldm.models.diffusion.sampler import Sampler from ldm.modules.diffusionmodules.util import noise_like -class PLMSSampler(Sampler): +class PLMSSampler(Sampler, CrossAttentionControllableDiffusionMixin): def __init__(self, model, schedule='linear', device=None, **kwargs): super().__init__(model,schedule,model.num_timesteps, device) + def prepare_to_sample(self, t_enc, **kwargs): + super().prepare_to_sample(t_enc, **kwargs) + + edited_conditioning = kwargs.get('edited_conditioning', None) + edit_opcodes = kwargs.get('conditioning_edit_opcodes', None) + + self.setup_cross_attention_control_if_appropriate(self.model, edited_conditioning, edit_opcodes) + + # this is the essential routine @torch.no_grad() def p_sample( @@ -41,14 +51,18 @@ class PLMSSampler(Sampler): unconditional_conditioning is None or unconditional_guidance_scale == 1.0 ): + # damian0815 does not think this code path is ever used e_t = self.model.apply_model(x, t, c) else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model( - x_in, t_in, c_in - ).chunk(2) + #x_in = torch.cat([x] * 2) + #t_in = torch.cat([t] * 2) + #c_in = torch.cat([unconditional_conditioning, c]) + #e_t_uncond, e_t = self.model.apply_model( + # x_in, t_in, c_in + #).chunk(2) + e_t_uncond, e_t = self.do_cross_attention_controllable_diffusion_step(x, t, unconditional_conditioning, c, self.model, + model_forward_callback=lambda x, sigma, cond: self.model.apply_model(x, sigma, cond)) + e_t = e_t_uncond + unconditional_guidance_scale * ( e_t - e_t_uncond ) diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index eb7caebba0..b8377ebb39 100644 
--- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -192,6 +192,7 @@ class Sampler(object): unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, steps=S, + **kwargs ) return samples, intermediates @@ -216,6 +217,7 @@ class Sampler(object): unconditional_guidance_scale=1.0, unconditional_conditioning=None, steps=None, + **kwargs ): b = shape[0] time_range = ( @@ -233,7 +235,7 @@ class Sampler(object): dynamic_ncols=True, ) old_eps = [] - self.prepare_to_sample(t_enc=total_steps) + self.prepare_to_sample(t_enc=total_steps,**kwargs) img = self.get_initial_image(x_T,shape,total_steps) # probably don't need this at all @@ -323,7 +325,7 @@ class Sampler(object): iterator = tqdm(time_range, desc='Decoding image', total=total_steps) x_dec = x_latent x0 = init_latent - self.prepare_to_sample(t_enc=total_steps) + self.prepare_to_sample(t_enc=total_steps,**kwargs) for i, step in enumerate(iterator): index = total_steps - i - 1 @@ -414,5 +416,3 @@ class Sampler(object): ''' return self.model.q_sample(x0,ts) - - From d572af2acf66a4c150dc87122379378987be5486 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Tue, 18 Oct 2022 22:22:47 +0200 Subject: [PATCH 11/54] fix cross-attention on k* samplers --- ldm/models/diffusion/ksampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 78d5978efe..417458f18f 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -43,7 +43,7 @@ class CFGDenoiser(nn.Module, CrossAttentionControllableDiffusionMixin): edited_conditioning = kwargs.get('edited_conditioning', None) conditioning_edit_opcodes = kwargs.get('conditioning_edit_opcodes', None) - self.setup_cross_attention_control_if_appropriate(self.model, edited_conditioning, conditioning_edit_opcodes) + self.setup_cross_attention_control_if_appropriate(self.inner_model, edited_conditioning, conditioning_edit_opcodes) def forward(self, x, sigma, uncond, cond, cond_scale): From 2b79a716aac7968550382d991a2e5f85bf8ed34e Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Tue, 18 Oct 2022 22:54:51 +0200 Subject: [PATCH 12/54] wip hi-res fix --- ldm/invoke/generator/txt2img2img.py | 8 ++++++-- ldm/models/diffusion/sampler.py | 3 +++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index 945ebadd90..afe680ac6e 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -22,7 +22,7 @@ class Txt2Img2Img(Generator): Return value depends on the seed at the time you call it kwargs are 'width' and 'height' """ - uc, c = conditioning + uc, c, ec, edit_opcodes = conditioning @torch.no_grad() def make_image(x_T): @@ -60,7 +60,9 @@ class Txt2Img2Img(Generator): unconditional_guidance_scale = cfg_scale, unconditional_conditioning = uc, eta = ddim_eta, - img_callback = step_callback + img_callback = step_callback, + edited_conditioning = ec, + conditioning_edit_opcodes = edit_opcodes ) print( @@ -94,6 +96,8 @@ class Txt2Img2Img(Generator): img_callback = step_callback, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, + edited_conditioning = ec, + conditioning_edit_opcodes = edit_opcodes ) if self.free_gpu_mem: diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index b8377ebb39..cd8940fa6e 100644 --- a/ldm/models/diffusion/sampler.py +++ 
b/ldm/models/diffusion/sampler.py @@ -309,6 +309,9 @@ class Sampler(object): use_original_steps=False, init_latent = None, mask = None, + edited_conditioning = None, + conditioning_edit_opcodes = None, + **kwargs ): timesteps = ( From 582880b314f6467c13d47a740d912966e51f259c Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Tue, 18 Oct 2022 23:23:38 +0200 Subject: [PATCH 13/54] add cross-attention support to im2img; prevent inpainting from crashing --- ldm/generate.py | 3 ++- ldm/invoke/generator/img2img.py | 7 +++++-- ldm/invoke/generator/inpaint.py | 3 ++- ldm/models/diffusion/ksampler.py | 7 ++++++- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/ldm/generate.py b/ldm/generate.py index 37df973291..45ed2e73d1 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -541,7 +541,8 @@ class Generate: image = Image.open(image_path) # used by multiple postfixers - uc, c = get_uc_and_c( + # todo: cross-attention + uc, c, _, _ = get_uc_and_c_and_ec( prompt, model =self.model, skip_normalize=opt.skip_normalize, log_tokens =opt.log_tokenization diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index 7fde1a94cf..7852591048 100644 --- a/ldm/invoke/generator/img2img.py +++ b/ldm/invoke/generator/img2img.py @@ -32,7 +32,7 @@ class Img2Img(Generator): ) # move to latent space t_enc = int(strength * steps) - uc, c = conditioning + uc, c, ec, edit_opcodes = conditioning def make_image(x_T): # encode (scaled latent) @@ -49,7 +49,10 @@ class Img2Img(Generator): img_callback = step_callback, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, - init_latent = self.init_latent, # changes how noising is performed in ksampler + init_latent = self.init_latent, + edited_conditioning = ec, + conditioning_edit_opcodes = edit_opcodes + # changes how noising is performed in ksampler ) return self.sample_to_image(samples) diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py index bc4b6133b3..8f01b4ad2d 100644 --- a/ldm/invoke/generator/inpaint.py +++ b/ldm/invoke/generator/inpaint.py @@ -45,7 +45,8 @@ class Inpaint(Img2Img): ) # move to latent space t_enc = int(strength * steps) - uc, c = conditioning + # todo: support cross-attention control + uc, c, _, _ = conditioning print(f">> target t_enc is {t_enc} steps") diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 417458f18f..a7be70c9ce 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -132,6 +132,7 @@ class KSampler(Sampler): use_original_steps=False, init_latent = None, mask = None, + **kwargs ): samples,_ = self.sample( batch_size = 1, @@ -143,7 +144,8 @@ class KSampler(Sampler): unconditional_conditioning = unconditional_conditioning, img_callback = img_callback, x0 = init_latent, - mask = mask + mask = mask, + **kwargs ) return samples @@ -238,6 +240,8 @@ class KSampler(Sampler): index, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + edited_conditioning=None, + conditioning_edit_opcodes=None, **kwargs, ): if self.model_wrap is None: @@ -263,6 +267,7 @@ class KSampler(Sampler): # so the actual formula for indexing into sigmas: # sigma_index = (steps-index) s_index = t_enc - index - 1 + self.model_wrap.prepare_to_sample(s_index, edited_conditioning=edited_conditioning, conditioning_edit_opcodes=conditioning_edit_opcodes) img = K.sampling.__dict__[f'_{self.schedule}']( self.model_wrap, img, From 824cb201b10cf14d18604d132b730a1b0d6796d5 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: 
Tue, 18 Oct 2022 23:42:59 +0200 Subject: [PATCH 14/54] pass img2img ddim/plms edited conditioning through kwargs --- ldm/models/diffusion/sampler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index cd8940fa6e..879b85495a 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -309,8 +309,6 @@ class Sampler(object): use_original_steps=False, init_latent = None, mask = None, - edited_conditioning = None, - conditioning_edit_opcodes = None, **kwargs ): From 147d39cb7c8ad4025e70eb16e607a930c9579157 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Wed, 19 Oct 2022 18:19:55 +0200 Subject: [PATCH 15/54] wip refactoring shared InvokeAI diffuser mixin to component --- ldm/models/diffusion/ddim.py | 24 +++--- ldm/models/diffusion/ksampler.py | 47 ++++------- ldm/models/diffusion/plms.py | 28 +++---- ...ention.py => shared_invokeai_diffusion.py} | 82 ++++++++++++++----- 4 files changed, 104 insertions(+), 77 deletions(-) rename ldm/models/diffusion/{cross_attention.py => shared_invokeai_diffusion.py} (77%) diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index 4980b03c42..a1f76c18e2 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -1,25 +1,32 @@ """SAMPLING ONLY.""" +from typing import Union import torch import numpy as np from tqdm import tqdm from functools import partial from ldm.invoke.devices import choose_torch_device -from ldm.models.diffusion.cross_attention import CrossAttentionControllableDiffusionMixin +from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from ldm.models.diffusion.sampler import Sampler from ldm.modules.diffusionmodules.util import noise_like -class DDIMSampler(Sampler, CrossAttentionControllableDiffusionMixin): +class DDIMSampler(Sampler): def __init__(self, model, schedule='linear', device=None, **kwargs): super().__init__(model,schedule,model.num_timesteps,device) + self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model, + model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond)) + def prepare_to_sample(self, t_enc, **kwargs): super().prepare_to_sample(t_enc, **kwargs) edited_conditioning = kwargs.get('edited_conditioning', None) - edit_opcodes = kwargs.get('conditioning_edit_opcodes', None) - self.setup_cross_attention_control_if_appropriate(self.model, edited_conditioning, edit_opcodes) + if edited_conditioning is not None: + edit_opcodes = kwargs.get('conditioning_edit_opcodes', None) + self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, edit_opcodes) + else: + self.invokeai_diffuser.cleanup_cross_attention_control() # This is the central routine @@ -27,7 +34,7 @@ class DDIMSampler(Sampler, CrossAttentionControllableDiffusionMixin): def p_sample( self, x, - c, + c: Union[torch.Tensor, list], t, index, repeat_noise=False, @@ -51,12 +58,7 @@ class DDIMSampler(Sampler, CrossAttentionControllableDiffusionMixin): e_t = self.model.apply_model(x, t, c) else: - e_t_uncond, e_t = self.do_cross_attention_controllable_diffusion_step(x, t, unconditional_conditioning, c, self.model, - model_forward_callback=lambda x, sigma, cond: self.model.apply_model(x, sigma, cond)) - - e_t = e_t_uncond + unconditional_guidance_scale * ( - e_t - e_t_uncond - ) + e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale) if score_corrector is not None: assert self.model.parameterization == 
'eps' diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index a7be70c9ce..ea28b0e40b 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -1,19 +1,11 @@ """wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers""" -from enum import Enum import k_diffusion as K import torch -import torch.nn as nn -from ldm.invoke.devices import choose_torch_device -from ldm.models.diffusion.sampler import Sampler -from ldm.util import rand_perlin_2d -from ldm.modules.diffusionmodules.util import ( - make_ddim_sampling_parameters, - make_ddim_timesteps, - noise_like, - extract_into_tensor, -) -from ldm.models.diffusion.cross_attention import CrossAttentionControl, CrossAttentionControllableDiffusionMixin +from torch import nn + +from .sampler import Sampler +from .shared_invokeai_diffusion import InvokeAIDiffuserComponent def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): @@ -30,27 +22,32 @@ def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): return torch.clamp(result, min=minval, max=maxval) -class CFGDenoiser(nn.Module, CrossAttentionControllableDiffusionMixin): +class CFGDenoiser(nn.Module): def __init__(self, model, threshold = 0, warmup = 0): super().__init__() self.inner_model = model self.threshold = threshold self.warmup_max = warmup self.warmup = max(warmup / 10, 1) + self.invokeai_diffuser = InvokeAIDiffuserComponent(model, + model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond)) def prepare_to_sample(self, t_enc, **kwargs): edited_conditioning = kwargs.get('edited_conditioning', None) - conditioning_edit_opcodes = kwargs.get('conditioning_edit_opcodes', None) - self.setup_cross_attention_control_if_appropriate(self.inner_model, edited_conditioning, conditioning_edit_opcodes) + if edited_conditioning is not None: + conditioning_edit_opcodes = kwargs.get('conditioning_edit_opcodes', None) + self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, conditioning_edit_opcodes) + else: + self.invokeai_diffuser.cleanup_cross_attention_control() def forward(self, x, sigma, uncond, cond, cond_scale): - unconditioned_next_x, conditioned_next_x = self.do_cross_attention_controllable_diffusion_step(x, sigma, uncond, cond, self.inner_model, - model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond)) + final_next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale) + # apply threshold if self.warmup < self.warmup_max: thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) self.warmup += 1 @@ -58,9 +55,8 @@ class CFGDenoiser(nn.Module, CrossAttentionControllableDiffusionMixin): thresh = self.threshold if thresh > self.threshold: thresh = self.threshold - # to scale how much effect conditioning has, calculate the changes it does and then scale that - scaled_delta = (conditioned_next_x - unconditioned_next_x) * cond_scale - return cfg_apply_threshold(unconditioned_next_x + scaled_delta, thresh) + return cfg_apply_threshold(final_next_x, thresh) + class KSampler(Sampler): @@ -75,16 +71,6 @@ class KSampler(Sampler): self.ds = None self.s_in = None - def forward(self, x, sigma, uncond, cond, cond_scale): - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model( - x_in, sigma_in, cond=cond_in - ).chunk(2) - return uncond + (cond - uncond) * cond_scale - - def make_schedule( self, 
ddim_num_steps, @@ -303,3 +289,4 @@ class KSampler(Sampler): Overrides parent method to return the q_sample of the inner model. ''' return self.model.inner_model.q_sample(x0,ts) + diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index eb778813a0..d5c227b9f1 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -5,22 +5,28 @@ import numpy as np from tqdm import tqdm from functools import partial from ldm.invoke.devices import choose_torch_device -from ldm.models.diffusion.cross_attention import CrossAttentionControllableDiffusionMixin +from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from ldm.models.diffusion.sampler import Sampler from ldm.modules.diffusionmodules.util import noise_like -class PLMSSampler(Sampler, CrossAttentionControllableDiffusionMixin): +class PLMSSampler(Sampler): def __init__(self, model, schedule='linear', device=None, **kwargs): super().__init__(model,schedule,model.num_timesteps, device) + self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model, + model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond)) + def prepare_to_sample(self, t_enc, **kwargs): super().prepare_to_sample(t_enc, **kwargs) edited_conditioning = kwargs.get('edited_conditioning', None) - edit_opcodes = kwargs.get('conditioning_edit_opcodes', None) - self.setup_cross_attention_control_if_appropriate(self.model, edited_conditioning, edit_opcodes) + if edited_conditioning is not None: + edit_opcodes = kwargs.get('conditioning_edit_opcodes', None) + self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, edit_opcodes) + else: + self.invokeai_diffuser.cleanup_cross_attention_control() # this is the essential routine @@ -51,21 +57,11 @@ class PLMSSampler(Sampler, CrossAttentionControllableDiffusionMixin): unconditional_conditioning is None or unconditional_guidance_scale == 1.0 ): - # damian0815 does not think this code path is ever used + # damian0815 does not know if this code path is ever used e_t = self.model.apply_model(x, t, c) else: - #x_in = torch.cat([x] * 2) - #t_in = torch.cat([t] * 2) - #c_in = torch.cat([unconditional_conditioning, c]) - #e_t_uncond, e_t = self.model.apply_model( - # x_in, t_in, c_in - #).chunk(2) - e_t_uncond, e_t = self.do_cross_attention_controllable_diffusion_step(x, t, unconditional_conditioning, c, self.model, - model_forward_callback=lambda x, sigma, cond: self.model.apply_model(x, sigma, cond)) - e_t = e_t_uncond + unconditional_guidance_scale * ( - e_t - e_t_uncond - ) + e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale) if score_corrector is not None: assert self.model.parameterization == 'eps' diff --git a/ldm/models/diffusion/cross_attention.py b/ldm/models/diffusion/shared_invokeai_diffusion.py similarity index 77% rename from ldm/models/diffusion/cross_attention.py rename to ldm/models/diffusion/shared_invokeai_diffusion.py index c0760fff47..e2d6ba5fb6 100644 --- a/ldm/models/diffusion/cross_attention.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -1,33 +1,70 @@ from enum import Enum +from typing import Callable + import torch +class InvokeAIDiffuserComponent: -class CrossAttentionControllableDiffusionMixin: + class Conditioning: + def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None): + """ + :param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768) + :param 
edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how to map original conditioning tokens to edited conditioning tokens + """ + #self.conditioning = conditioning + #self.unconditioning = unconditioning + self.edited_conditioning = edited_conditioning + self.edit_opcodes = edit_opcodes - def setup_cross_attention_control_if_appropriate(self, model, edited_conditioning, edit_opcodes): + ''' + The aim of this component is to provide a single place for code that can be applied identically to + all InvokeAI diffusion procedures. + + At the moment it includes the following features: + * Cross Attention Control ("prompt2prompt") + ''' + + def __init__(self, model, model_forward_callback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]): + """ + :param model: the unet model to pass through to cross attention control + :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning) + """ + self.model = model + self.model_forward_callback = model_forward_callback + + + def setup_cross_attention_control(self, edited_conditioning, edit_opcodes): self.edited_conditioning = edited_conditioning + CrossAttentionControl.setup_attention_editing(self.model, edited_conditioning, edit_opcodes) - if edited_conditioning is not None: - # a cat sitting on a car - CrossAttentionControl.setup_attention_editing(model, edited_conditioning, edit_opcodes) - else: - # pass through the attention func but don't act on it - CrossAttentionControl.clear_attention_editing(model) + def cleanup_cross_attention_control(self): + self.edited_conditioning = None + CrossAttentionControl.clear_attention_editing(self.model) - def cleanup_cross_attention_control(self, model): - CrossAttentionControl.clear_attention_editing(model) - def do_cross_attention_controllable_diffusion_step(self, x, sigma, unconditioning, conditioning, model, model_forward_callback): + def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor, + unconditioning: torch.Tensor, conditioning: torch.Tensor, + unconditional_guidance_scale: float): + """ + :param x: Current latents + :param sigma: aka t, passed to the internal model to control how much denoising will occur + :param unconditioning: [B x 77 x 768] embeddings for unconditioned output + :param conditioning: [B x 77 x 768] embeddings for conditioned output + :param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has + :param model: the unet model to pass through to cross attention control + :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning) + :return: the new latents after applying the model to x using unconditioning and CFG-scaled conditioning. 
+ """ - CrossAttentionControl.clear_requests(model) + CrossAttentionControl.clear_requests(self.model) if self.edited_conditioning is None: # faster batched path x_twice = torch.cat([x]*2) sigma_twice = torch.cat([sigma]*2) both_conditionings = torch.cat([unconditioning, conditioning]) - unconditioned_next_x, conditioned_next_x = model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2) + unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2) else: # slower non-batched path (20% slower on mac MPS) # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of @@ -37,19 +74,24 @@ class CrossAttentionControllableDiffusionMixin: # representing batched uncond + cond, but then when it comes to applying the saved attention, the # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.) # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well. - unconditioned_next_x = model_forward_callback(x, sigma, unconditioning) + unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) # process x using the original prompt, saving the attention maps - CrossAttentionControl.request_save_attention_maps(model) - _ = model_forward_callback(x, sigma, cond=conditioning) - CrossAttentionControl.clear_requests(model) + CrossAttentionControl.request_save_attention_maps(self.model) + _ = self.model_forward_callback(x, sigma, cond=conditioning) + CrossAttentionControl.clear_requests(self.model) # process x again, using the saved attention maps to control where self.edited_conditioning will be applied - CrossAttentionControl.request_apply_saved_attention_maps(model) - conditioned_next_x = model_forward_callback(x, sigma, self.edited_conditioning) + CrossAttentionControl.request_apply_saved_attention_maps(self.model) + conditioned_next_x = self.model_forward_callback(x, sigma, self.edited_conditioning) CrossAttentionControl.clear_requests(model) - return unconditioned_next_x, conditioned_next_x + + # to scale how much effect conditioning has, calculate the changes it does and then scale that + scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale + combined_next_x = unconditioned_next_x + scaled_delta + + return combined_next_x # adapted from bloc97's CrossAttentionControl colab # https://github.com/bloc97/CrossAttentionControl From 1ffd4a9e06728089d14c1fafd332fe8e5759ae30 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Wed, 19 Oct 2022 19:57:20 +0200 Subject: [PATCH 16/54] refactored single diffusion path seems to be working for all samplers --- ldm/invoke/generator/img2img.py | 5 ++- ldm/invoke/generator/txt2img.py | 6 ++- ldm/invoke/generator/txt2img2img.py | 8 ++-- ldm/models/diffusion/ddim.py | 16 +++----- ldm/models/diffusion/ksampler.py | 23 +++++------ ldm/models/diffusion/plms.py | 12 +++--- .../diffusion/shared_invokeai_diffusion.py | 39 ++++++++++++------- 7 files changed, 57 insertions(+), 52 deletions(-) diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index 7852591048..0a12bd90e5 100644 --- a/ldm/invoke/generator/img2img.py +++ b/ldm/invoke/generator/img2img.py @@ -7,6 +7,7 @@ import numpy as np from ldm.invoke.devices import choose_autocast from ldm.invoke.generator.base import Generator from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.shared_invokeai_diffusion 
import InvokeAIDiffuserComponent class Img2Img(Generator): def __init__(self, model, precision): @@ -33,6 +34,7 @@ class Img2Img(Generator): t_enc = int(strength * steps) uc, c, ec, edit_opcodes = conditioning + structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes) def make_image(x_T): # encode (scaled latent) @@ -50,8 +52,7 @@ class Img2Img(Generator): unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, init_latent = self.init_latent, - edited_conditioning = ec, - conditioning_edit_opcodes = edit_opcodes + structured_conditioning = structured_conditioning # changes how noising is performed in ksampler ) diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index 669f3d81ff..6e158562c5 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -5,6 +5,8 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator import torch import numpy as np from ldm.invoke.generator.base import Generator +from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent + class Txt2Img(Generator): def __init__(self, model, precision): @@ -20,6 +22,7 @@ class Txt2Img(Generator): """ self.perlin = perlin uc, c, ec, edit_opcodes = conditioning + structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes) @torch.no_grad() def make_image(x_T): @@ -43,8 +46,7 @@ class Txt2Img(Generator): verbose = False, unconditional_guidance_scale = cfg_scale, unconditional_conditioning = uc, - edited_conditioning = ec, - conditioning_edit_opcodes = edit_opcodes, + structured_conditioning = structured_conditioning, eta = ddim_eta, img_callback = step_callback, threshold = threshold, diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index afe680ac6e..52a14aae74 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -7,6 +7,7 @@ import numpy as np import math from ldm.invoke.generator.base import Generator from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent class Txt2Img2Img(Generator): @@ -23,6 +24,7 @@ class Txt2Img2Img(Generator): kwargs are 'width' and 'height' """ uc, c, ec, edit_opcodes = conditioning + structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes) @torch.no_grad() def make_image(x_T): @@ -61,8 +63,7 @@ class Txt2Img2Img(Generator): unconditional_conditioning = uc, eta = ddim_eta, img_callback = step_callback, - edited_conditioning = ec, - conditioning_edit_opcodes = edit_opcodes + structured_conditioning = structured_conditioning ) print( @@ -96,8 +97,7 @@ class Txt2Img2Img(Generator): img_callback = step_callback, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, - edited_conditioning = ec, - conditioning_edit_opcodes = edit_opcodes + structured_conditioning = structured_conditioning ) if self.free_gpu_mem: diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index a1f76c18e2..0ab6911247 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -2,10 +2,6 @@ from typing import Union import torch -import numpy as np -from tqdm import tqdm -from functools import partial -from ldm.invoke.devices import choose_torch_device from ldm.models.diffusion.shared_invokeai_diffusion import 
InvokeAIDiffuserComponent from ldm.models.diffusion.sampler import Sampler from ldm.modules.diffusionmodules.util import noise_like @@ -20,13 +16,12 @@ class DDIMSampler(Sampler): def prepare_to_sample(self, t_enc, **kwargs): super().prepare_to_sample(t_enc, **kwargs) - edited_conditioning = kwargs.get('edited_conditioning', None) + structured_conditioning = kwargs.get('structured_conditioning', None) - if edited_conditioning is not None: - edit_opcodes = kwargs.get('conditioning_edit_opcodes', None) - self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, edit_opcodes) + if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control: + self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning) else: - self.invokeai_diffuser.cleanup_cross_attention_control() + self.invokeai_diffuser.remove_cross_attention_control() # This is the central routine @@ -54,10 +49,9 @@ class DDIMSampler(Sampler): unconditional_conditioning is None or unconditional_guidance_scale == 1.0 ): - # damian0815 does not think this code path is ever used + # damian0815 would like to know when/if this code path is used e_t = self.model.apply_model(x, t, c) else: - e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale) if score_corrector is not None: diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index ea28b0e40b..a8291e32c1 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -34,18 +34,17 @@ class CFGDenoiser(nn.Module): def prepare_to_sample(self, t_enc, **kwargs): - edited_conditioning = kwargs.get('edited_conditioning', None) + structured_conditioning = kwargs.get('structured_conditioning', None) - if edited_conditioning is not None: - conditioning_edit_opcodes = kwargs.get('conditioning_edit_opcodes', None) - self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, conditioning_edit_opcodes) + if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control: + self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning) else: - self.invokeai_diffuser.cleanup_cross_attention_control() + self.invokeai_diffuser.remove_cross_attention_control() def forward(self, x, sigma, uncond, cond, cond_scale): - final_next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale) + next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale) # apply threshold if self.warmup < self.warmup_max: @@ -55,7 +54,7 @@ class CFGDenoiser(nn.Module): thresh = self.threshold if thresh > self.threshold: thresh = self.threshold - return cfg_apply_threshold(final_next_x, thresh) + return cfg_apply_threshold(next_x, thresh) @@ -165,8 +164,7 @@ class KSampler(Sampler): log_every_t=100, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - edited_conditioning=None, - conditioning_edit_opcodes=None, + structured_conditioning=None, threshold = 0, perlin = 0, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
@@ -199,7 +197,7 @@ class KSampler(Sampler): x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10)) - model_wrap_cfg.prepare_to_sample(S, edited_conditioning=edited_conditioning, conditioning_edit_opcodes=conditioning_edit_opcodes) + model_wrap_cfg.prepare_to_sample(S, structured_conditioning=structured_conditioning) extra_args = { 'cond': conditioning, 'uncond': unconditional_conditioning, @@ -226,8 +224,7 @@ class KSampler(Sampler): index, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - edited_conditioning=None, - conditioning_edit_opcodes=None, + structured_conditioning=None, **kwargs, ): if self.model_wrap is None: @@ -253,7 +250,7 @@ class KSampler(Sampler): # so the actual formula for indexing into sigmas: # sigma_index = (steps-index) s_index = t_enc - index - 1 - self.model_wrap.prepare_to_sample(s_index, edited_conditioning=edited_conditioning, conditioning_edit_opcodes=conditioning_edit_opcodes) + self.model_wrap.prepare_to_sample(s_index, structured_conditioning=structured_conditioning) img = K.sampling.__dict__[f'_{self.schedule}']( self.model_wrap, img, diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index d5c227b9f1..98975525ed 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -20,13 +20,12 @@ class PLMSSampler(Sampler): def prepare_to_sample(self, t_enc, **kwargs): super().prepare_to_sample(t_enc, **kwargs) - edited_conditioning = kwargs.get('edited_conditioning', None) + structured_conditioning = kwargs.get('structured_conditioning', None) - if edited_conditioning is not None: - edit_opcodes = kwargs.get('conditioning_edit_opcodes', None) - self.invokeai_diffuser.setup_cross_attention_control(edited_conditioning, edit_opcodes) + if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control: + self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning) else: - self.invokeai_diffuser.cleanup_cross_attention_control() + self.invokeai_diffuser.remove_cross_attention_control() # this is the essential routine @@ -57,10 +56,9 @@ class PLMSSampler(Sampler): unconditional_conditioning is None or unconditional_guidance_scale == 1.0 ): - # damian0815 does not know if this code path is ever used + # damian0815 would like to know when/if this code path is used e_t = self.model.apply_model(x, t, c) else: - e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale) if score_corrector is not None: diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index e2d6ba5fb6..d4f059de09 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -6,17 +6,22 @@ import torch class InvokeAIDiffuserComponent: - class Conditioning: + class StructuredConditioning: def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None): """ :param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768) :param edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how to map original conditioning tokens to edited conditioning tokens """ + # TODO migrate conditioning and unconditioning here, too #self.conditioning = conditioning #self.unconditioning = unconditioning self.edited_conditioning = edited_conditioning 
self.edit_opcodes = edit_opcodes + @property + def wants_cross_attention_control(self): + return self.edited_conditioning is not None + ''' The aim of this component is to provide a single place for code that can be applied identically to all InvokeAI diffusion procedures. @@ -34,14 +39,20 @@ class InvokeAIDiffuserComponent: self.model_forward_callback = model_forward_callback - def setup_cross_attention_control(self, edited_conditioning, edit_opcodes): - self.edited_conditioning = edited_conditioning - CrossAttentionControl.setup_attention_editing(self.model, edited_conditioning, edit_opcodes) + def setup_cross_attention_control(self, conditioning: StructuredConditioning): + self.conditioning = conditioning + CrossAttentionControl.setup_cross_attention_control(self.model, conditioning.edited_conditioning, conditioning.edit_opcodes) - def cleanup_cross_attention_control(self): - self.edited_conditioning = None - CrossAttentionControl.clear_attention_editing(self.model) + def remove_cross_attention_control(self): + self.conditioning = None + CrossAttentionControl.remove_cross_attention_control(self.model) + @property + def edited_conditioning(self): + if self.conditioning is None: + return None + else: + return self.conditioning.edited_conditioning def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor, unconditioning: torch.Tensor, conditioning: torch.Tensor, @@ -78,13 +89,13 @@ class InvokeAIDiffuserComponent: # process x using the original prompt, saving the attention maps CrossAttentionControl.request_save_attention_maps(self.model) - _ = self.model_forward_callback(x, sigma, cond=conditioning) + _ = self.model_forward_callback(x, sigma, conditioning) CrossAttentionControl.clear_requests(self.model) # process x again, using the saved attention maps to control where self.edited_conditioning will be applied CrossAttentionControl.request_apply_saved_attention_maps(self.model) conditioned_next_x = self.model_forward_callback(x, sigma, self.edited_conditioning) - CrossAttentionControl.clear_requests(model) + CrossAttentionControl.clear_requests(self.model) # to scale how much effect conditioning has, calculate the changes it does and then scale that @@ -100,14 +111,16 @@ class CrossAttentionControl: @classmethod - def clear_attention_editing(cls, model): + def remove_cross_attention_control(cls, model): cls.remove_attention_function(model) @classmethod - def setup_attention_editing(cls, model, - substitute_conditioning: torch.Tensor, - edit_opcodes: list): + def setup_cross_attention_control(cls, model, + substitute_conditioning: torch.Tensor, + edit_opcodes: list): """ + Inject attention parameters and functions into the passed in model to enable cross attention editing. + :param model: The unet model to inject into. :param substitute_conditioning: The "edited" conditioning vector, [Bx77x768] :param edit_opcodes: Opcodes from difflib.SequenceMatcher describing how the base From c3b992db968984fc3bfa4f1311461dc0177eb8af Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sat, 15 Oct 2022 23:44:54 +0200 Subject: [PATCH 17/54] Squashed commit of the following: commit 9bb0b5d0036c4dffbb72ce11e097fae4ab63defd Author: Damian at mba Date: Sat Oct 15 23:43:41 2022 +0200 undo local_files_only stuff commit eed93f5d30c34cfccaf7497618ae9af17a5ecfbb Author: Damian at mba Date: Sat Oct 15 23:40:37 2022 +0200 Revert "Merge branch 'development-invoke' into fix-prompts" This reverts commit 7c40892a9f184f7e216f14d14feb0411c5a90e24, reversing changes made to e3f2dd62b0548ca6988818ef058093a4f5b022f2. 
commit f06d6024e345c69e6d5a91ab5423925a68ee95a7 Author: Damian at mba Date: Thu Oct 13 23:30:16 2022 +0200 more efficiently handle multiple conditioning commit 5efdfcbcd980ce6202ab74e7f90e7415ce7260da Merge: b9c0dc5 ac08bb6 Author: Damian at mba Date: Thu Oct 13 14:51:01 2022 +0200 Merge branch 'optional-disable-karras-schedule' into fix-prompts commit ac08bb6fd25e19a9d35cf6c199e66500fb604af1 Author: Damian at mba Date: Thu Oct 13 14:50:43 2022 +0200 append '*use_model_sigmas*' to prompt string to use model sigmas commit 70d8c05a3ff329409f76204f4af94e55d468ab8b Author: Damian at mba Date: Thu Oct 13 12:12:17 2022 +0200 make karras scheduling switchable commit d60df54f69968e2fb22809c55e23b3c02f37ad63 replaced the model's own scheduling with karras scheduling. this has changed image generation (seems worse now?) this commit wraps the change in a bool. commit b9c0dc5f1a658a0e6c3936000e9ae559e1c7a1db Author: Damian at mba Date: Wed Oct 12 20:16:00 2022 +0200 add test of more complex conjunction commit 9ac0c15cc0d7b5f6df3289d3ad474260972a17be Author: Damian at mba Date: Wed Oct 12 17:18:25 2022 +0200 improve comments commit ad33bce60590b87b2a93e90f16dc9d3e935d04a5 Author: Damian at mba Date: Wed Oct 12 17:04:46 2022 +0200 put back thresholding stuff commit 4852c698a325049834ba0d4b358f07210bc7171a Author: Damian at mba Date: Wed Oct 12 14:25:02 2022 +0200 notes on improving conjunction efficiency commit a53bb1e5b68025d09642b935ae6a9a015cfaf2d6 Author: Damian at mba Date: Wed Oct 12 14:14:33 2022 +0200 optional weights support for Conjunction commit fec79ab15e4f0c84dd61cb1b45a5e6a72ae4aaeb Author: Damian at mba Date: Wed Oct 12 12:07:27 2022 +0200 fix blend error and log parsing output commit 1f751c2a039f9c97af57b18e0f019512631d5a25 Author: Damian at mba Date: Wed Oct 12 10:33:33 2022 +0200 fix broken euler sampler commit 02f8148d17efe4b6bde8d29b827092a0626363ee Author: Damian at mba Date: Wed Oct 12 10:24:20 2022 +0200 cleanup prompt parser commit 8028d49ae6c16c0d6ec9c9de9c12d56c32201421 Author: Damian at mba Date: Wed Oct 12 10:14:18 2022 +0200 explicit conjunction, improve flattening logic commit 8a1710892185f07eb77483f7edae0fc4d6bbb250 Author: Damian at mba Date: Tue Oct 11 22:59:30 2022 +0200 adapt multi-conditioning to also work with ddim commit 53802a839850d0d1ff017c6bafe457c4bed750b0 Author: Damian at mba Date: Tue Oct 11 22:31:42 2022 +0200 unconditioning is also fancy-prompt-syntaxable commit 7c40892a9f184f7e216f14d14feb0411c5a90e24 Merge: e3f2dd6 dbe0da4 Author: Damian at mba Date: Tue Oct 11 21:39:54 2022 +0200 Merge branch 'development-invoke' into fix-prompts commit e3f2dd62b0548ca6988818ef058093a4f5b022f2 Merge: eef0e48 06f542e Author: Damian at mba Date: Tue Oct 11 21:38:09 2022 +0200 Merge remote-tracking branch 'upstream/development' into fix-prompts commit eef0e484c2eaa1bd4e0e0b1d3f8d7bba38478144 Author: Damian at mba Date: Tue Oct 11 21:26:25 2022 +0200 fix run-on paren-less attention, add some comments commit fd29afdf0e9f5e0cdc60239e22480c36ca0aaeca Author: Damian at mba Date: Tue Oct 11 21:03:02 2022 +0200 python 3.9 compatibility commit 26f7646eef7f39bc8f7ce805e747df0f723464da Author: Damian at mba Date: Tue Oct 11 20:58:42 2022 +0200 first pass connecting PromptParser to conditioning commit ae53dff3796d7b9a5e7ed30fa1edb0374af6cd8d Author: Damian at mba Date: Tue Oct 11 20:51:15 2022 +0200 update frontend dist commit 9be4a59a2d76f49e635474b5984bfca826a5dab4 Author: Damian at mba Date: Tue Oct 11 19:01:39 2022 +0200 fix issues with correctness checking FlattenedPrompt commit 
3be212323eab68e72a363a654124edd9809e4cf0 Author: Damian at mba Date: Tue Oct 11 18:43:16 2022 +0200 parsing nested seems to work pretty ok commit acd73eb08cf67c27cac8a22934754321256f56a9 Author: Damian at mba Date: Tue Oct 11 18:26:17 2022 +0200 wip introducing FlattenedPrompt class commit 71698d5c7c2ac855b690d8ef67e8830148c59eda Author: Damian at mba Date: Tue Oct 11 15:59:42 2022 +0200 recursive attention weighting seems to actually work commit a4e1ec6b20deb7cc0cd12737bdbd266e56144709 Author: Damian at mba Date: Tue Oct 11 15:06:24 2022 +0200 now apparently almost supported nested attention commit da76fd1ddf22a3888cdc08fd4fed38d8b178e524 Author: Damian at mba Date: Tue Oct 11 13:23:37 2022 +0200 wip prompt parsing commit dbe0da4572c2ac22f26a7afd722349a5680a9e47 Author: Kyle Schouviller Date: Mon Oct 10 22:32:35 2022 -0700 Adding node-based invocation apps commit 8f2a2ffc083366de74d7dae471b50b6f98a7c5f8 Author: Damian at mba Date: Mon Oct 10 19:03:18 2022 +0200 fix merge issues commit 73118dee2a8f4891700756e014caf1c9ca629267 Merge: fd00844 12413b0 Author: Damian at mba Date: Mon Oct 10 12:42:48 2022 +0200 Merge remote-tracking branch 'upstream/development' into fix-prompts commit fd0084413541013c2cf71e006af0392719bef53d Author: Damian at mba Date: Mon Oct 10 12:39:38 2022 +0200 wip prompt parsing commit 0be9363db9307859d2b65cffc6af01f57d7873a4 Author: Damian at mba Date: Mon Oct 10 03:20:06 2022 +0200 better +/- attention parsing commit 5383f691874a58ab01cda1e4fac6cf330146526a Author: Damian at mba Date: Mon Oct 10 02:27:47 2022 +0200 prompt parser seems to work commit 591d098a33ce35462428d8c169501d8ed73615ab Author: Damian at mba Date: Sun Oct 9 20:25:37 2022 +0200 supports weighting unconditioning, cross-attention with | commit 7a7220563aa05a2980235b5b908362f66b728309 Author: Damian at mba Date: Sun Oct 9 18:15:56 2022 +0200 i think cross attention might be working? 
commit 951ed391e7126bff228c18b2db304ad28d59644a Author: Damian at mba Date: Sun Oct 9 16:04:54 2022 +0200 weighted CFG denoiser working with a single item commit ee532a0c2827368c9e45a6a5f3975666402873da Author: Damian at mba Date: Sun Oct 9 06:33:40 2022 +0200 wip probably doesn't work or compile commit 14654bcbd207b9ca28a6cbd37dbd967d699b062d Author: Damian at mba Date: Fri Oct 7 18:11:48 2022 +0200 use tan() to calculate embedding weight for <1 attentions commit 1a8e76b31aa5abf5150419ebf3b29d4658d07f2b Author: Damian at mba Date: Fri Oct 7 16:14:54 2022 +0200 fix bad math.max reference commit f697ff896875876ccaa1e5527405bdaa7ed27cde Author: Damian at mba Date: Fri Oct 7 15:55:57 2022 +0200 respect http[s]x protocol when making socket.io middleware commit 41d3dd4eeae8d4efb05dfb44fc6d8aac5dc468ab Author: Damian at mba Date: Fri Oct 7 13:29:54 2022 +0200 fractional weighting works, by blending with prompts excluding the word commit 087fb6dfb3e8f5e84de8c911f75faa3e3fa3553c Author: Damian at mba Date: Fri Oct 7 10:52:03 2022 +0200 wip doing weights <1 by averaging with conditioning absent the lower-weighted fragment commit 3c49e3f3ec7c18dc60f3e18ed2f7f0d97aad3a47 Author: Damian at mba Date: Fri Oct 7 10:36:15 2022 +0200 notate CFGDenoiser, perhaps commit d2bcf1bb522026ebf209ad0103f6b370383e5070 Author: Damian at mba Date: Thu Oct 6 05:04:47 2022 +0200 hack blending syntax to test attention weighting more extensively commit 94904ef2cf917f74ec23ef7a570e12ff8255b048 Author: Damian at mba Date: Thu Oct 6 04:56:37 2022 +0200 conditioning works, apparently commit 7c6663ddd70f665fd1308b6dd74f92ca393a8df5 Author: Damian at mba Date: Thu Oct 6 02:20:24 2022 +0200 attention weighting, definitely works in positive direction commit 5856d453a9b020bc1a28ff643ae1f58c12c9be73 Author: Damian at mba Date: Tue Oct 4 19:02:14 2022 +0200 wip bubbling weights down commit a2ed14fd9b7d3cb36b6c5348018b364c76d1e892 Author: Damian at mba Date: Tue Oct 4 17:35:39 2022 +0200 bring in changes from PC --- backend/server.py | 2 +- configs/stable-diffusion/v1-inference.yaml | 2 +- ldm/generate.py | 6 +- ldm/invoke/conditioning.py | 82 +++-- ldm/invoke/prompt_parser.py | 331 +++++++++++++++++++++ ldm/models/diffusion/ddim.py | 18 +- ldm/models/diffusion/ddpm.py | 4 +- ldm/models/diffusion/ksampler.py | 29 +- ldm/models/diffusion/sampler.py | 53 ++++ ldm/modules/encoders/modules.py | 203 +++++++++++++ tests/test_prompt_parser.py | 136 +++++++++ 11 files changed, 823 insertions(+), 43 deletions(-) create mode 100644 ldm/invoke/prompt_parser.py create mode 100644 tests/test_prompt_parser.py diff --git a/backend/server.py b/backend/server.py index 7b8a8a5a69..f14c141e12 100644 --- a/backend/server.py +++ b/backend/server.py @@ -527,7 +527,7 @@ def parameters_to_generated_image_metadata(parameters): rfc_dict["sampler"] = parameters["sampler_name"] # display weighted subprompts (liable to change) - subprompts = split_weighted_subprompts(parameters["prompt"]) + subprompts = split_weighted_subprompts(parameters["prompt"], skip_normalize=True) subprompts = [{"prompt": x[0], "weight": x[1]} for x in subprompts] rfc_dict["prompt"] = subprompts diff --git a/configs/stable-diffusion/v1-inference.yaml b/configs/stable-diffusion/v1-inference.yaml index 9c773077b6..baf91f6e26 100644 --- a/configs/stable-diffusion/v1-inference.yaml +++ b/configs/stable-diffusion/v1-inference.yaml @@ -76,4 +76,4 @@ model: target: torch.nn.Identity cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + target: 
ldm.modules.encoders.modules.WeightedFrozenCLIPEmbedder diff --git a/ldm/generate.py b/ldm/generate.py index 7fb68dec0a..965d37a240 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -438,7 +438,7 @@ class Generate: sampler=self.sampler, steps=steps, cfg_scale=cfg_scale, - conditioning=(uc, c), + conditioning=(uc, c), # here change to arrays ddim_eta=ddim_eta, image_callback=image_callback, # called after the final image is generated step_callback=step_callback, # called after each intermediate image is generated @@ -477,6 +477,10 @@ class Generate: print('**Interrupted** Partial results will be returned.') else: raise KeyboardInterrupt + # brute-force fallback + except Exception as e: + print(traceback.format_exc(), file=sys.stderr) + print('>> Could not generate image.') toc = time.time() print('>> Usage stats:') diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index fedd965a2c..9b67d5040d 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -12,42 +12,76 @@ log_tokenization() print out colour-coded tokens and warn if trunca import re import torch -def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False): +from .prompt_parser import PromptParser, Fragment, Attention, Blend, Conjunction, FlattenedPrompt +from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder + + +def get_uc_and_c(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False): + # Extract Unconditioned Words From Prompt unconditioned_words = '' unconditional_regex = r'\[(.*?)\]' - unconditionals = re.findall(unconditional_regex, prompt) + unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned) if len(unconditionals) > 0: unconditioned_words = ' '.join(unconditionals) # Remove Unconditioned Words From Prompt unconditional_regex_compile = re.compile(unconditional_regex) - clean_prompt = unconditional_regex_compile.sub(' ', prompt) - prompt = re.sub(' +', ' ', clean_prompt) + clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned) + prompt_string_cleaned = re.sub(' +', ' ', clean_prompt) + else: + prompt_string_cleaned = prompt_string_uncleaned - uc = model.get_learned_conditioning([unconditioned_words]) + pp = PromptParser() - # get weighted sub-prompts - weighted_subprompts = split_weighted_subprompts( - prompt, skip_normalize - ) + def build_conditioning_list(prompt_string:str): + parsed_conjunction: Conjunction = pp.parse(prompt_string) + print(f"parsed '{prompt_string}' to {parsed_conjunction}") + assert (type(parsed_conjunction) is Conjunction) + + conditioning_list = [] + def make_embeddings_for_flattened_prompt(flattened_prompt: FlattenedPrompt): + if type(flattened_prompt) is not FlattenedPrompt: + raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead" + fragments = [x[0] for x in flattened_prompt.children] + attention_weights = [x[1] for x in flattened_prompt.children] + print(fragments, attention_weights) + return model.get_learned_conditioning([fragments], attention_weights=[attention_weights]) + + for part,weight in zip(parsed_conjunction.prompts, parsed_conjunction.weights): + if type(part) is Blend: + blend:Blend = part + embeddings_to_blend = None + for flattened_prompt in blend.prompts: + this_embedding = make_embeddings_for_flattened_prompt(flattened_prompt) + embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat((embeddings_to_blend, this_embedding)) + blended_embeddings = 
WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0), blend.weights, normalize=blend.normalize_weights) + conditioning_list.append((blended_embeddings, weight)) + else: + flattened_prompt: FlattenedPrompt = part + embeddings = make_embeddings_for_flattened_prompt(flattened_prompt) + conditioning_list.append((embeddings, weight)) + + return conditioning_list + + positive_conditioning_list = build_conditioning_list(prompt_string_cleaned) + negative_conditioning_list = build_conditioning_list(unconditioned_words) + + if len(negative_conditioning_list) == 0: + negative_conditioning = model.get_learned_conditioning([['']], attention_weights=[[1]]) + else: + if len(negative_conditioning_list)>1: + print("cannot do conjunctions on unconditioning for now") + negative_conditioning = negative_conditioning_list[0][0] + + #positive_conditioning_list.append((get_blend_prompts_and_weights(prompt), this_weight)) + #print("got empty_conditionining with shape", empty_conditioning.shape, "c[0][0] with shape", positive_conditioning[0][0].shape) + + # "unconditioned" means "the conditioning tensor is empty" + uc = negative_conditioning + c = positive_conditioning_list - if len(weighted_subprompts) > 1: - # i dont know if this is correct.. but it works - c = torch.zeros_like(uc) - # normalize each "sub prompt" and add it - for subprompt, weight in weighted_subprompts: - log_tokenization(subprompt, model, log_tokens, weight) - c = torch.add( - c, - model.get_learned_conditioning([subprompt]), - alpha=weight, - ) - else: # just standard 1 prompt - log_tokenization(prompt, model, log_tokens, 1) - c = model.get_learned_conditioning([prompt]) - uc = model.get_learned_conditioning([unconditioned_words]) return (uc, c) def split_weighted_subprompts(text, skip_normalize=False)->list: diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py new file mode 100644 index 0000000000..c976918291 --- /dev/null +++ b/ldm/invoke/prompt_parser.py @@ -0,0 +1,331 @@ +import pyparsing +import pyparsing as pp +from pyparsing import original_text_for + + +class Prompt(): + + def __init__(self, parts: list): + for c in parts: + allowed_types = [Fragment, Attention, CFGScale] + if type(c) not in allowed_types: + raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {allowed_types} are allowed") + self.children = parts + def __repr__(self): + return f"Prompt:{self.children}" + def __eq__(self, other): + return type(other) is Prompt and other.children == self.children + +class FlattenedPrompt(): + def __init__(self, parts: list): + # verify type correctness + for c in parts: + if type(c) is not tuple: + raise PromptParser.ParsingException( + f"FlattenedPrompt cannot contain {type(c)}, only ('text', weight) tuples are allowed") + text = c[0] + weight = c[1] + if type(text) is not str: + raise PromptParser.ParsingException(f"FlattenedPrompt cannot contain {type(c)}, only ('text', weight) tuples are allowed") + if type(weight) is not float and type(weight) is not int: + raise PromptParser.ParsingException( + f"FlattenedPrompt cannot contain {type(c)}, only ('text', weight) tuples are allowed") + # all looks good + self.children = parts + + def __repr__(self): + return f"FlattenedPrompt:{self.children}" + def __eq__(self, other): + return type(other) is FlattenedPrompt and other.children == self.children + + +class Attention(): + + def __init__(self, weight: float, children: list): + self.weight = weight + self.children = children + #print(f"A: requested attention 
'{children}' to {weight}") + + def __repr__(self): + return f"Attention:'{self.children}' @ {self.weight}" + def __eq__(self, other): + return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment + + +class CFGScale(): + def __init__(self, scale_factor: float, fragment: str): + self.fragment = fragment + self.scale_factor = scale_factor + #print(f"S: requested CFGScale '{fragment}' x {scale_factor}") + + def __repr__(self): + return f"CFGScale:'{self.fragment}' x {self.scale_factor}" + def __eq__(self, other): + return type(other) is CFGScale and other.scale_factor == self.scale_factor and other.fragment == self.fragment + + + +class Fragment(): + def __init__(self, text: str): + assert(type(text) is str) + self.text = text + + def __repr__(self): + return "Fragment:'"+self.text+"'" + def __eq__(self, other): + return type(other) is Fragment and other.text == self.text + +class Conjunction(): + def __init__(self, prompts: list, weights: list = None): + # force everything to be a Prompt + #print("making conjunction with", parts) + self.prompts = [x if (type(x) is Prompt or type(x) is Blend or type(x) is FlattenedPrompt) + else Prompt(x) for x in prompts] + self.weights = [1.0]*len(self.prompts) if weights is None else list(weights) + if len(self.weights) != len(self.prompts): + raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}") + self.type = 'AND' + + def __repr__(self): + return f"Conjunction:{self.prompts} | weights {self.weights}" + def __eq__(self, other): + return type(other) is Conjunction \ + and other.prompts == self.prompts \ + and other.weights == self.weights + + +class Blend(): + def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True): + #print("making Blend with prompts", prompts, "and weights", weights) + if len(prompts) != len(weights): + raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}") + for c in prompts: + if type(c) is not Prompt and type(c) is not FlattenedPrompt: + raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts")) + # upcast all lists to Prompt objects + self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt) + else Prompt(x) for x in prompts] + self.prompts = prompts + self.weights = weights + self.normalize_weights = normalize_weights + + def __repr__(self): + return f"Blend:{self.prompts} | weights {self.weights}" + def __eq__(self, other): + return other.__repr__() == self.__repr__() + + +class PromptParser(): + + class ParsingException(Exception): + pass + + def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9): + + self.attention_plus_base = attention_plus_base + self.attention_minus_base = attention_minus_base + + self.root = self.build_parser_logic() + + + def parse(self, prompt: str) -> [list]: + ''' + :param prompt: The prompt string to parse + :return: a tuple + ''' + #print(f"!!parsing '{prompt}'") + + if len(prompt.strip()) == 0: + return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0]) + + root = self.root.parse_string(prompt) + #print(f"'{prompt}' parsed to root", root) + #fused = fuse_fragments(parts) + #print("fused to", fused) + + return self.flatten(root[0]) + + def flatten(self, root: Conjunction): + + def fuse_fragments(items): + # print("fusing fragments in ", items) + result = [] + for x in items: + last_weight = result[-1][1] if 
len(result) > 0 else None + this_text = x[0] + this_weight = x[1] + if last_weight is not None and last_weight == this_weight: + last_text = result[-1][0] + result[-1] = (last_text + ' ' + this_text, last_weight) + else: + result.append(x) + return result + + def flatten_internal(node, weight_scale, results, prefix): + #print(prefix + "flattening", node, "...") + if type(node) is pp.ParseResults: + for x in node: + results = flatten_internal(x, weight_scale, results, prefix+'pr') + #print(prefix, " ParseResults expanded, results is now", results) + elif type(node) is Fragment: + results.append((node.text, float(weight_scale))) + elif type(node) is Attention: + #if node.weight < 1: + # todo: inject a blend when flattening attention with weight <1" + for c in node.children: + results = flatten_internal(c, weight_scale*node.weight, results, prefix+' ') + elif type(node) is Blend: + flattened_subprompts = [] + #print(" flattening blend with prompts", node.prompts, "weights", node.weights) + for prompt in node.prompts: + # prompt is a list + flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ') + results += [Blend(prompts=flattened_subprompts, weights=node.weights)] + elif type(node) is Prompt: + #print(prefix + "about to flatten Prompt with children", node.children) + flattened_prompt = [] + for child in node.children: + flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ') + results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))] + #print(prefix + "after flattening Prompt, results is", results) + else: + raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}") + #print(prefix + "-> after flattening", type(node), "results is", results) + return results + + #print("flattening", root) + + flattened_parts = [] + for part in root.prompts: + flattened_parts += flatten_internal(part, 1.0, [], ' C| ') + weights = root.weights + return Conjunction(flattened_parts, weights) + + + + def build_parser_logic(self): + + lparen = pp.Literal("(").suppress() + rparen = pp.Literal(")").suppress() + # accepts int or float notation, always maps to float + number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float)) + SPACE_CHARS = ' \t\n' + + prompt_part = pp.Forward() + word = pp.Forward() + + def make_fragment(x): + #print("### making fragment for", x) + if type(x) is str: + return Fragment(x) + elif type(x) is pp.ParseResults or type(x) is list: + return Fragment(' '.join([s for s in x])) + else: + raise PromptParser.ParsingException("Cannot make fragment from " + str(x)) + + # attention control of the form +(phrase) / -(phrase) / (phrase) + # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight + attention = pp.Forward() + attention_head = (number | pp.Word('+') | pp.Word('-'))\ + .set_name("attention_head")\ + .set_debug(False) + fragment_inside_attention = pp.CharsNotIn(SPACE_CHARS+'()')\ + .set_parse_action(make_fragment)\ + .set_name("fragment_inside_attention")\ + .set_debug(False) + attention_with_parens = pp.Forward() + attention_with_parens_body = pp.nested_expr(content=pp.delimited_list((attention_with_parens | fragment_inside_attention), delim=SPACE_CHARS)) + attention_with_parens << (attention_head + attention_with_parens_body) + + def make_attention(x): + # print("making Attention from parsing with args", x0, x1) + weight = 1 + # number(str) + if type(x[0]) is 
float or type(x[0]) is int: + weight = float(x[0]) + # +(str) or -(str) or +str or -str + elif type(x[0]) is str: + base = self.attention_plus_base if x[0][0] == '+' else self.attention_minus_base + weight = pow(base, len(x[0])) + # print("Making attention with children of type", [str(type(x)) for x in x1]) + return Attention(weight=weight, children=x[1]) + + attention_with_parens.set_parse_action(make_attention)\ + .set_name("attention_with_parens")\ + .set_debug(False) + + # attention control of the form ++word --word (no parens) + attention_without_parens = ( + (pp.Word('+') | pp.Word('-')) + + pp.CharsNotIn(SPACE_CHARS+'()').set_parse_action(lambda x: [[make_fragment(x)]]) + )\ + .set_name("attention_without_parens")\ + .set_debug(False) + attention_without_parens.set_parse_action(make_attention) + + attention << (attention_with_parens | attention_without_parens)\ + .set_name("attention")\ + .set_debug(False) + + # fragments of text with no attention control + word << pp.Word(pp.printables).set_parse_action(lambda x: Fragment(' '.join([s for s in x]))) + word.set_name("word") + word.set_debug(False) + prompt_part << (attention | word) + prompt_part.set_debug(False) + prompt_part.set_name("prompt_part") + + # root prompt definition + prompt = pp.Group(pp.OneOrMore(prompt_part))\ + .set_parse_action(lambda x: Prompt(x[0])) + + # weighted blend of prompts + # ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or + # int weights. + # can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c) + + def make_prompt_from_quoted_string(x): + #print(' got quoted prompt', x) + + x_unquoted = x[0][1:-1] + if len(x_unquoted.strip()) == 0: + # print(' b : just an empty string') + return Prompt([Fragment('')]) + # print(' b parsing ', c_unquoted) + x_parsed = prompt.parse_string(x_unquoted) + #print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed) + return x_parsed[0] + + quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string) + quoted_prompt.set_name('quoted_prompt') + + blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms') + blend_weights = pp.delimited_list(number).set_name('blend_weights') + blend = pp.Group(lparen + pp.Group(blend_terms) + rparen + + pp.Literal(".blend").suppress() + + lparen + pp.Group(blend_weights) + rparen).set_name('blend') + blend.set_debug(False) + + + blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1])) + + conjunction_terms = blend_terms.copy().set_name('conjunction_terms') + conjunction_weights = blend_weights.copy().set_name('conjunction_weights') + conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen + + pp.Literal(".and").suppress() + + lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction') + def make_conjunction(x): + parts_raw = x[0][0] + weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw) + parts = [part for part in parts_raw] + return Conjunction(parts, weights) + conjunction_with_parens_and_quotes.set_parse_action(make_conjunction) + + implicit_conjunction = pp.OneOrMore(blend | prompt) + implicit_conjunction.set_parse_action(lambda x: Conjunction(x)) + + conjunction = conjunction_with_parens_and_quotes | implicit_conjunction + conjunction.set_debug(False) + + # top-level is a conjunction of one or more blends or prompts + return conjunction diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index 
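(Editorial aside, not part of the patch.) A short usage sketch for the grammar defined in ldm/invoke/prompt_parser.py above; the expected results quoted in the comments follow tests/test_prompt_parser.py, added later in this patch.

    from ldm.invoke.prompt_parser import PromptParser

    pp = PromptParser()

    # attention: an explicit number scales a parenthesised fragment; '+'/'-' multiply
    # by attention_plus_base / attention_minus_base (1.1 / 0.9 by default)
    pp.parse("fire 0.5(flames)")
    # -> Conjunction([FlattenedPrompt([('fire', 1.0), ('flames', 0.5)])])

    # blends merge whole prompts in embedding space with the given weights
    pp.parse('("fire", "fire flames").blend(0.7, 0.3)')
    # -> Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]),
    #                        FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])])

    # explicit conjunctions are diffused separately and merged as weighted deltas
    pp.parse('("fire", "flames").and(2, 1)')
    # -> Conjunction([FlattenedPrompt([('fire', 1.0)]),
    #                 FlattenedPrompt([('flames', 1.0)])], weights=[2.0, 1.0])

Downstream, get_uc_and_c() above converts each part of the Conjunction into an (embedding, weight) pair, so the positive conditioning handed to the samplers is now a list rather than a single tensor; the sampler-side merging of that list appears in ldm/models/diffusion/sampler.py further down.
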
f5dada8627..5120d92c48 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -34,23 +34,21 @@ class DDIMSampler(Sampler): b, *_, device = *x.shape, x.device if ( - unconditional_conditioning is None - or unconditional_guidance_scale == 1.0 + (unconditional_conditioning is None + or unconditional_guidance_scale == 1.0) + and c is not list ): e_t = self.model.apply_model(x, t, c) else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * ( - e_t - e_t_uncond - ) + e_t = self.apply_weighted_conditioning_list(x, t, self.model.apply_model, unconditional_conditioning, c, unconditional_guidance_scale) if score_corrector is not None: assert self.model.parameterization == 'eps' + if c is list and len(c)>1: + print("warning: ddim score modifier currently ignores all but the first part of the prompt conjunction, this is probably wrong") + corrector_c = [c[0][0] if c is list else c] e_t = score_corrector.modify_score( - self.model, e_t, x, t, c, **corrector_kwargs + self.model, e_t, x, t, corrector_c, **corrector_kwargs ) alphas = ( diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index 4b62b5e393..0f55786323 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -820,13 +820,13 @@ class LatentDiffusion(DDPM): ) return self.scale_factor * z - def get_learned_conditioning(self, c): + def get_learned_conditioning(self, c, attention_weights=None): if self.cond_stage_forward is None: if hasattr(self.cond_stage_model, 'encode') and callable( self.cond_stage_model.encode ): c = self.cond_stage_model.encode( - c, embedding_manager=self.embedding_manager + c, embedding_manager=self.embedding_manager, attention_weights=attention_weights ) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index ac0615b30c..4d37c8cf9b 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -38,7 +38,8 @@ class CFGDenoiser(nn.Module): x_in = torch.cat([x] * 2) sigma_in = torch.cat([sigma] * 2) cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + unconditioned_x, conditioned_x = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + if self.warmup < self.warmup_max: thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) self.warmup += 1 @@ -46,7 +47,28 @@ class CFGDenoiser(nn.Module): thresh = self.threshold if thresh > self.threshold: thresh = self.threshold - return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh) + + # damian0815 thinking out loud notes: + # b + (a - b)*scale + # starting at the output that emerges applying the negative prompt (by default ''), + # (-> this is why the unconditioning feels like hammer) + # move toward the positive prompt by an amount controlled by cond_scale. 
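    # --- editorial aside, not part of the patch -------------------------------
    # The return statement below is plain classifier-free guidance. With, say,
    # cond_scale = 7.5 it computes
    #     unconditioned_x + 7.5 * (conditioned_x - unconditioned_x)
    # i.e. start from the negative-prompt prediction and step 7.5x further along
    # the direction the positive prompt pulls in. The ProgrammableCFGDenoiser and
    # Sampler.apply_weighted_conditioning_list added below generalize this to a
    # weighted sum of such deltas, one per conjunction part:
    #     out = uncond + (sum_i w_i * (cond_i - uncond)) * guidance_scale
    # ---------------------------------------------------------------------------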
+ return cfg_apply_threshold(unconditioned_x + (conditioned_x - unconditioned_x) * cond_scale, thresh) + + +class ProgrammableCFGDenoiser(CFGDenoiser): + def forward(self, x, sigma, uncond, cond, cond_scale): + forward_lambda = lambda x, t, c: self.inner_model(x, t, cond=c) + x_new = Sampler.apply_weighted_conditioning_list(x, sigma, forward_lambda, uncond, cond, cond_scale) + + if self.warmup < self.warmup_max: + thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) + self.warmup += 1 + else: + thresh = self.threshold + if thresh > self.threshold: + thresh = self.threshold + return cfg_apply_threshold(x_new, threshold=thresh) class KSampler(Sampler): @@ -181,7 +203,6 @@ class KSampler(Sampler): ) # sigmas are set up in make_schedule - we take the last steps items - total_steps = len(self.sigmas) sigmas = self.sigmas[-S-1:] # x_T is variation noise. When an init image is provided (in x0) we need to add @@ -194,7 +215,7 @@ class KSampler(Sampler): else: x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] - model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10)) + model_wrap_cfg = ProgrammableCFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10)) extra_args = { 'cond': conditioning, 'uncond': unconditional_conditioning, diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index ff705513f8..42704f1175 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -4,6 +4,8 @@ ldm.models.diffusion.sampler Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc ''' +from math import ceil + import torch import numpy as np from tqdm import tqdm @@ -411,3 +413,54 @@ class Sampler(object): return self.model.inner_model.q_sample(x0,ts) ''' return self.model.q_sample(x0,ts) + + + @classmethod + def apply_weighted_conditioning_list(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale): + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) # aka sigmas + + deltas = None + uncond_latents = None + weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)] + + # below is fugly omg + num_actual_conditionings = len(c_or_weighted_c_list) + conditionings = [uc] + [c for c,weight in weighted_cond_list] + weights = [1] + [weight for c,weight in weighted_cond_list] + chunk_count = ceil(len(conditionings)/2) + assert(len(conditionings)>=2, "need at least one uncond and one cond") + deltas = None + for chunk_index in range(chunk_count): + offset = chunk_index*2 + chunk_size = min(2, len(conditionings)-offset) + + if chunk_size == 1: + c_in = conditionings[offset] + latents_a = forward_func(x_in[:-1], t_in[:-1], c_in) + latents_b = None + else: + c_in = torch.cat(conditionings[offset:offset+2]) + latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2) + + # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining + if chunk_index == 0: + uncond_latents = latents_a + deltas = latents_b - uncond_latents + else: + deltas = torch.cat((deltas, latents_a - uncond_latents)) + if latents_b is not None: + deltas = torch.cat((deltas, latents_b - uncond_latents)) + + # merge the weighted deltas together into a single merged delta + per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device) + normalize = False + if normalize: + per_delta_weights /= torch.sum(per_delta_weights) + reshaped_weights = 
per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1)) + deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True) + + # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale) + # assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale)))) + + return uncond_latents + deltas_merged * global_guidance_scale diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 426fccced3..857a8a8e3e 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -1,3 +1,5 @@ +import math + import torch import torch.nn as nn from functools import partial @@ -454,6 +456,207 @@ class FrozenCLIPEmbedder(AbstractEncoder): def encode(self, text, **kwargs): return self(text, **kwargs) +class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): + + attention_weights_key = "attention_weights" + + def build_token_list_fragment(self, fragment: str, weight: float) -> (torch.Tensor, torch.Tensor): + batch_encoding = self.tokenizer( + fragment, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding='none', + return_tensors='pt', + ) + return batch_encoding, torch.ones_like(batch_encoding) * weight + + + def forward(self, text: list, **kwargs): + ''' + + :param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different + weights shall be applied. + :param kwargs: If the keyword arg "attention_weights" is passed, it shall contain a batch of lists of weights + for the prompt fragments. In this case text must contain batches of lists of prompt fragments. + :return: A tensor of shape (B, 77, 768) containing weighted embeddings + ''' + if self.attention_weights_key not in kwargs: + # fallback to base class implementation + return super().forward(text, **kwargs) + + attention_weights = kwargs[self.attention_weights_key] + # self.transformer doesn't like receiving "attention_weights" as an argument + kwargs.pop(self.attention_weights_key) + + batch_z = None + for fragments, weights in zip(text, attention_weights): + + # First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively + # applying a multiplier to the CFG scale on a per-token basis). + # For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept + # captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active + # interest, however small, in redness; what the user probably intends when they attach the number 0.01 to + # "red" is to tell SD that it should almost completely *ignore* redness). + # To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt + # string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the + # closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment. + + # handle weights >=1 + tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights) + base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs) + + # this is our starting point + embeddings = base_embedding.unsqueeze(0) + per_embedding_weights = [1.0] + + # now handle weights <1 + # Do this by building extra embeddings tensors that lack the words being <1 weighted. 
These will be lerped + # with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting + # embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words + # removed. + # eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding + # for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it + # such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain". + for index, fragment_weight in enumerate(weights): + if fragment_weight < 1: + fragments_without_this = fragments[:index] + fragments[index+1:] + weights_without_this = weights[:index] + weights[index+1:] + tokens, per_token_weights = self.get_tokens_and_weights(fragments_without_this, weights_without_this) + embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs) + + embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1) + # weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0 + # if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding + # therefore: + # fragment_weight = 1: we are at base_z => lerp weight 0 + # fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1 + # fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf + # so let's use tan(), because: + # tan is 0.0 at 0, + # 1.0 at PI/4, and + # inf at PI/2 + # -> tan((1-weight)*PI/2) should give us ideal lerp weights + epsilon = 1e-9 + fragment_weight = max(epsilon, fragment_weight) # inf is bad + embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2) + # todo handle negative weight? + + per_embedding_weights.append(embedding_lerp_weight) + + lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0) + + print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}") + + # append to batch + batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat((batch_z, lerped_embeddings.unsqueeze(0)), dim=1) + + # should have shape (B, 77, 768) + print(f"assembled all tokens into tensor of shape {batch_z.shape}") + + return batch_z + + @classmethod + def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor: + per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device) + if normalize: + per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights) + reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,)) + #reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape) + return torch.sum(embeddings * reshaped_weights, dim=1) + # lerped embeddings has shape (77, 768) + + + def get_tokens_and_weights(self, fragments: list[str], weights: list[float]) -> (torch.Tensor, torch.Tensor): + ''' + + :param fragments: + :param weights: Per-fragment weights (CFG scaling). No need for these to be normalized. They will not be normalized here and that's fine. 
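(Editorial aside, not part of the patch.) The tan() rule described above can be sanity-checked in isolation; a minimal sketch of the lerp weights it assigns, assuming the same epsilon clamp used in the hunk:

    import math

    def embedding_lerp_weight(fragment_weight: float, epsilon: float = 1e-9) -> float:
        # weight given to the "prompt without this fragment" embedding
        fragment_weight = max(epsilon, fragment_weight)
        return math.tan((1.0 - fragment_weight) * math.pi / 2)

    for w in (1.0, 0.5, 0.1):
        print(w, round(embedding_lerp_weight(w), 3))
    # 1.0 -> 0.0     stay entirely at the full prompt
    # 0.5 -> 1.0     after normalization, half-way to the prompt without the fragment
    # 0.1 -> 6.314   the prompt without the fragment dominates
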
+ :return: + ''' + # empty is meaningful + if len(fragments) == 0 and len(weights) == 0: + fragments = [''] + weights = [1] + item_encodings = self.tokenizer( + fragments, + truncation=True, + max_length=self.max_length, + return_overflowing_tokens=False, + padding='do_not_pad', + return_tensors=None, # just give me a list of ints + )['input_ids'] + all_tokens = [] + per_token_weights = [] + print("all fragments:", fragments, weights) + for index, fragment in enumerate(item_encodings): + weight = weights[index] + print("processing fragment", fragment, weight) + fragment_tokens = item_encodings[index] + print("fragment", fragment, "processed to", fragment_tokens) + # trim bos and eos markers before appending + all_tokens.extend(fragment_tokens[1:-1]) + per_token_weights.extend([weight] * (len(fragment_tokens) - 2)) + + if len(all_tokens) > self.max_length - 2: + print("prompt is too long and has been truncated") + all_tokens = all_tokens[:self.max_length - 2] + + # pad out to a 77-entry array: [eos_token, , eos_token, ..., eos_token] + # (77 = self.max_length) + pad_length = self.max_length - 1 - len(all_tokens) + all_tokens.insert(0, self.tokenizer.bos_token_id) + all_tokens.extend([self.tokenizer.eos_token_id] * pad_length) + per_token_weights.insert(0, 1) + per_token_weights.extend([1] * pad_length) + + all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device) + per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device) + print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}") + return all_tokens_tensor, per_token_weights_tensor + + def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor: + ''' + Build a tensor representing the passed-in tokens, each of which has a weight. + :param tokens: A tensor of shape (77) containing token ids (integers) + :param per_token_weights: A tensor of shape (77) containing weights (floats) + :param method: Whether to multiply the whole feature vector for each token or just its distance from an "empty" feature vector + :param kwargs: passed on to self.transformer() + :return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings. + ''' + #print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}") + z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs) + batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape) + + if weight_delta_from_empty: + empty_tokens = self.tokenizer([''] * z.shape[0], + truncation=True, + max_length=self.max_length, + padding='max_length', + return_tensors='pt' + )['input_ids'].to(self.device) + empty_z = self.transformer(input_ids=empty_tokens, **kwargs) + z_delta_from_empty = z - empty_z + weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded) + + weighted_z_delta_from_empty = (weighted_z-empty_z) + print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() ) + + #print("using empty-delta method, first 5 rows:") + #print(weighted_z[:5]) + + return weighted_z + + else: + original_mean = z.mean() + z *= batch_weights_expanded + after_weighting_mean = z.mean() + # correct the mean. 
not sure if this is right but it's what the automatic1111 fork of SD does + mean_correction_factor = original_mean/after_weighting_mean + z *= mean_correction_factor + return z + class FrozenCLIPTextEmbedder(nn.Module): """ diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py new file mode 100644 index 0000000000..207475d02e --- /dev/null +++ b/tests/test_prompt_parser.py @@ -0,0 +1,136 @@ +import unittest + +from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt + +def parse_prompt(prompt_string): + pp = PromptParser() + #print(f"parsing '{prompt_string}'") + parse_result = pp.parse(prompt_string) + #print(f"-> parsed '{prompt_string}' to {parse_result}") + return parse_result + +class PromptParserTestCase(unittest.TestCase): + + def test_empty(self): + self.assertEqual(Conjunction([FlattenedPrompt([('', 1)])]), parse_prompt('')) + + def test_basic(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire (flames)', 1)])]), parse_prompt("fire (flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([("fire flames", 1)])]), parse_prompt("fire flames")) + self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames", 1)])]), parse_prompt("fire, flames")) + self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames , fire", 1)])]), parse_prompt("fire, flames , fire")) + + def test_attention(self): + self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.5)])]), parse_prompt("0.5(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('fire flames', 0.5)])]), parse_prompt("0.5(fire flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flames', 1.1)])]), parse_prompt("+(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.9)])]), parse_prompt("-(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1), ('flames', 0.5)])]), parse_prompt("fire 0.5(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(1.1, 2))])]), parse_prompt("++(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(0.9, 2))])]), parse_prompt("--(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))])]), parse_prompt("---(flowers) +++flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))])]), parse_prompt("---(flowers) +++flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames+', pow(1.1, 3))])]), + parse_prompt("---(flowers) +++flames+")) + self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1)])]), + parse_prompt("+(pretty flowers)")) + self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1), (', the flames are too hot', 1)])]), + parse_prompt("+(pretty flowers), the flames are too hot")) + + def test_no_parens_attention_runon(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("++fire flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("--fire flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("flowers ++fire flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("flowers --fire flames")) + + + def test_explicit_conjunction(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), 
FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and(1,1)')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()')) + self.assertEqual( + Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("2.0(fire)", "-flames").and()')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]), + FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()')) + + def test_conjunction_weights(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[2.0,1.0]), parse_prompt('("fire", "flames").and(2,1)')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[1.0,2.0]), parse_prompt('("fire", "flames").and(1,2)')) + + with self.assertRaises(PromptParser.ParsingException): + parse_prompt('("fire", "flames").and(2)') + parse_prompt('("fire", "flames").and(2,1,2)') + + def test_complex_conjunction(self): + self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]), + parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle)\").and(0.5, 0.5)")) + + def test_badly_formed(self): + def make_untouched_prompt(prompt): + return Conjunction([FlattenedPrompt([(prompt, 1.0)])]) + + def assert_if_prompt_string_not_untouched(prompt): + self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt)) + + assert_if_prompt_string_not_untouched('a test prompt') + assert_if_prompt_string_not_untouched('a badly (formed test prompt') + assert_if_prompt_string_not_untouched('a badly formed test+ prompt') + assert_if_prompt_string_not_untouched('a badly (formed test+ prompt') + assert_if_prompt_string_not_untouched('a badly (formed test+ )prompt') + assert_if_prompt_string_not_untouched('a badly (formed test+ )prompt') + assert_if_prompt_string_not_untouched('(((a badly (formed test+ )prompt') + assert_if_prompt_string_not_untouched('(a (ba)dly (f)ormed test+ prompt') + self.assertEqual(Conjunction([FlattenedPrompt([('(a (ba)dly (f)ormed test+', 1.0), ('prompt', 1.1)])]), + parse_prompt('(a (ba)dly (f)ormed test+ +prompt')) + self.assertEqual(Conjunction([Blend([FlattenedPrompt([('((a badly (formed test+', 1.0)])], weights=[1.0])]), + parse_prompt('("((a badly (formed test+ ").blend(1.0)')) + + def test_blend(self): + self.assertEqual(Conjunction( + [Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]), + parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)") + ) + self.assertEqual(Conjunction([Blend( + [FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('hi', 1.0)])], + [0.7, 0.3, 1.0])]), + parse_prompt("(\"fire\", \"fire flames\", \"hi\").blend(0.7, 0.3, 1.0)") + ) + self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), + FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]), + FlattenedPrompt([('hi', 1.0)])], + weights=[0.7, 0.3, 1.0])]), + parse_prompt("(\"fire\", \"fire flames ++(hot)\", \"hi\").blend(0.7, 0.3, 1.0)") + ) + # blend a single entry is 
not a failure + self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]), + parse_prompt("(\"fire\").blend(0.7)") + ) + # blend with empty + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \"\").blend(0.7, 1)") + ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \" \").blend(0.7, 1)") + ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \" \").blend(0.7, 1)") + ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([(',', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \" , \").blend(0.7, 1)") + ) + + + def test_nested(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)])]), + parse_prompt('fire 2.0(flames 1.5(trees))')) + self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]), + FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])], + weights=[1.0, 1.0])]), + parse_prompt('("fire ++(flames)", "mountain 2(man)").blend(1,1)')) + +if __name__ == '__main__': + unittest.main() From 11d7e6b92f2deab1ad4998aad835920593a1e6d3 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sat, 15 Oct 2022 23:58:13 +0200 Subject: [PATCH 18/54] undo unwanted changes --- ldm/generate.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ldm/generate.py b/ldm/generate.py index 965d37a240..7fb68dec0a 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -438,7 +438,7 @@ class Generate: sampler=self.sampler, steps=steps, cfg_scale=cfg_scale, - conditioning=(uc, c), # here change to arrays + conditioning=(uc, c), ddim_eta=ddim_eta, image_callback=image_callback, # called after the final image is generated step_callback=step_callback, # called after each intermediate image is generated @@ -477,10 +477,6 @@ class Generate: print('**Interrupted** Partial results will be returned.') else: raise KeyboardInterrupt - # brute-force fallback - except Exception as e: - print(traceback.format_exc(), file=sys.stderr) - print('>> Could not generate image.') toc = time.time() print('>> Usage stats:') From c6ae9f117634bbfa5d385e98319a36a7701f6ecb Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sun, 16 Oct 2022 00:45:38 +0200 Subject: [PATCH 19/54] remove unnecessary assertion --- ldm/models/diffusion/sampler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index 42704f1175..417d1d4491 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -429,7 +429,6 @@ class Sampler(object): conditionings = [uc] + [c for c,weight in weighted_cond_list] weights = [1] + [weight for c,weight in weighted_cond_list] chunk_count = ceil(len(conditionings)/2) - assert(len(conditionings)>=2, "need at least one uncond and one cond") deltas = None for chunk_index in range(chunk_count): offset = chunk_index*2 From 61357e4e6eecc43b9d52859eb63b4db4e5ffafc1 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sun, 16 Oct 2022 01:53:44 +0200 Subject: [PATCH 20/54] be less verbose when assembling prompt --- ldm/invoke/conditioning.py | 13 +++++++------ ldm/modules/encoders/modules.py | 19 ++++++++++--------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git 
a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index 9b67d5040d..e3190f6ed6 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -35,9 +35,10 @@ def get_uc_and_c(prompt_string_uncleaned, model, log_tokens=False, skip_normaliz pp = PromptParser() - def build_conditioning_list(prompt_string:str): + def build_conditioning_list(prompt_string:str, verbose:bool = False): parsed_conjunction: Conjunction = pp.parse(prompt_string) - print(f"parsed '{prompt_string}' to {parsed_conjunction}") + if verbose: + print(f"parsed '{prompt_string}' to {parsed_conjunction}") assert (type(parsed_conjunction) is Conjunction) conditioning_list = [] @@ -46,7 +47,7 @@ def get_uc_and_c(prompt_string_uncleaned, model, log_tokens=False, skip_normaliz raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead" fragments = [x[0] for x in flattened_prompt.children] attention_weights = [x[1] for x in flattened_prompt.children] - print(fragments, attention_weights) + #print(fragments, attention_weights) return model.get_learned_conditioning([fragments], attention_weights=[attention_weights]) for part,weight in zip(parsed_conjunction.prompts, parsed_conjunction.weights): @@ -65,14 +66,14 @@ def get_uc_and_c(prompt_string_uncleaned, model, log_tokens=False, skip_normaliz return conditioning_list - positive_conditioning_list = build_conditioning_list(prompt_string_cleaned) - negative_conditioning_list = build_conditioning_list(unconditioned_words) + positive_conditioning_list = build_conditioning_list(prompt_string_cleaned, verbose=True) + negative_conditioning_list = build_conditioning_list(unconditioned_words, verbose=(len(unconditioned_words)>0) ) if len(negative_conditioning_list) == 0: negative_conditioning = model.get_learned_conditioning([['']], attention_weights=[[1]]) else: if len(negative_conditioning_list)>1: - print("cannot do conjunctions on unconditioning for now") + print("cannot do conjunctions on unconditioning for now, everything except the first prompt will be ignored") negative_conditioning = negative_conditioning_list[0][0] #positive_conditioning_list.append((get_blend_prompts_and_weights(prompt), this_weight)) diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 857a8a8e3e..fcd0363e80 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -547,13 +547,13 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0) - print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}") + #print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}") # append to batch batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat((batch_z, lerped_embeddings.unsqueeze(0)), dim=1) # should have shape (B, 77, 768) - print(f"assembled all tokens into tensor of shape {batch_z.shape}") + #print(f"assembled all tokens into tensor of shape {batch_z.shape}") return batch_z @@ -589,18 +589,19 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): )['input_ids'] all_tokens = [] per_token_weights = [] - print("all fragments:", fragments, weights) + #print("all fragments:", fragments, weights) for index, fragment in enumerate(item_encodings): weight = weights[index] - print("processing fragment", fragment, weight) + #print("processing fragment", fragment, weight) fragment_tokens = item_encodings[index] - 
print("fragment", fragment, "processed to", fragment_tokens) + #print("fragment", fragment, "processed to", fragment_tokens) # trim bos and eos markers before appending all_tokens.extend(fragment_tokens[1:-1]) per_token_weights.extend([weight] * (len(fragment_tokens) - 2)) - if len(all_tokens) > self.max_length - 2: - print("prompt is too long and has been truncated") + if (len(all_tokens) + 2) > self.max_length: + excess_token_count = (len(all_tokens) + 2) - self.max_length + print(f"prompt is {excess_token_count} token(s) too long and has been truncated") all_tokens = all_tokens[:self.max_length - 2] # pad out to a 77-entry array: [eos_token, , eos_token, ..., eos_token] @@ -613,7 +614,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device) per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device) - print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}") + #print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}") return all_tokens_tensor, per_token_weights_tensor def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor: @@ -641,7 +642,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded) weighted_z_delta_from_empty = (weighted_z-empty_z) - print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() ) + #print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() ) #print("using empty-delta method, first 5 rows:") #print(weighted_z[:5]) From 42883545f9d9fb308f7eb160a52028e21b6ae4b9 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Thu, 20 Oct 2022 01:42:04 +0200 Subject: [PATCH 21/54] add prompt language support for cross-attention .swap --- ldm/generate.py | 2 +- ldm/invoke/conditioning.py | 110 ++++++----- ldm/invoke/prompt_parser.py | 326 ++++++++++++++++++++++++++++++++ ldm/models/diffusion/ddpm.py | 8 +- ldm/modules/encoders/modules.py | 16 +- tests/test_prompt_parser.py | 173 +++++++++++++++++ 6 files changed, 585 insertions(+), 50 deletions(-) create mode 100644 ldm/invoke/prompt_parser.py create mode 100644 tests/test_prompt_parser.py diff --git a/ldm/generate.py b/ldm/generate.py index 45ed2e73d1..39bcc28162 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -35,7 +35,7 @@ from ldm.invoke.devices import choose_torch_device, choose_precision from ldm.invoke.conditioning import get_uc_and_c_and_ec from ldm.invoke.model_cache import ModelCache from ldm.invoke.seamless import configure_model_padding -from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale +#from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale def fix_func(orig): if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index 8c8f5eeb01..b7c8e55e66 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -11,71 +11,93 @@ log_tokenization() print out colour-coded tokens and warn if trunca ''' import re from difflib import SequenceMatcher +from typing import Union import torch -def get_uc_and_c_and_ec(prompt, model, log_tokens=False, skip_normalize=False): +from .prompt_parser import PromptParser, Blend, FlattenedPrompt, 
\ + CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend +from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder + + +def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False): + # Extract Unconditioned Words From Prompt unconditioned_words = '' unconditional_regex = r'\[(.*?)\]' - unconditionals = re.findall(unconditional_regex, prompt) + unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned) if len(unconditionals) > 0: unconditioned_words = ' '.join(unconditionals) # Remove Unconditioned Words From Prompt unconditional_regex_compile = re.compile(unconditional_regex) - clean_prompt = unconditional_regex_compile.sub(' ', prompt) - prompt = re.sub(' +', ' ', clean_prompt) + clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned) + prompt_string_cleaned = re.sub(' +', ' ', clean_prompt) + else: + prompt_string_cleaned = prompt_string_uncleaned - edited_words = None - edited_regex = r'\{(.*?)\}' - edited = re.findall(edited_regex, prompt) - if len(edited) > 0: - edited_words = ' '.join(edited) - edited_regex_compile = re.compile(edited_regex) - clean_prompt = edited_regex_compile.sub(' ', prompt) - prompt = re.sub(' +', ' ', clean_prompt) + pp = PromptParser() - # get weighted sub-prompts - weighted_subprompts = split_weighted_subprompts( - prompt, skip_normalize - ) + parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned) + parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words) - ec = None + conditioning = None + edited_conditioning = None edit_opcodes = None - uc, _ = model.get_learned_conditioning([unconditioned_words]) + if parsed_prompt is Blend: + blend: Blend = parsed_prompt + embeddings_to_blend = None + for flattened_prompt in blend.prompts: + this_embedding = make_embeddings_for_flattened_prompt(model, flattened_prompt) + embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat( + (embeddings_to_blend, this_embedding)) + conditioning, _ = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0), + blend.weights, + normalize=blend.normalize_weights) + else: + flattened_prompt: FlattenedPrompt = parsed_prompt + wants_cross_attention_control = any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children]) + if wants_cross_attention_control: + original_prompt = FlattenedPrompt() + edited_prompt = FlattenedPrompt() + for fragment in flattened_prompt.children: + if type(fragment) is CrossAttentionControlSubstitute: + original_prompt.append(fragment.original_fragment) + edited_prompt.append(fragment.edited_fragment) + elif type(fragment) is CrossAttentionControlAppend: + edited_prompt.append(fragment.fragment) + else: + # regular fragment + original_prompt.append(fragment) + edited_prompt.append(fragment) + original_embeddings, original_tokens = make_embeddings_for_flattened_prompt(model, original_prompt) + edited_embeddings, edited_tokens = make_embeddings_for_flattened_prompt(model, edited_prompt) - if len(weighted_subprompts) > 1: - # i dont know if this is correct.. 
but it works - c = torch.zeros_like(uc) - # normalize each "sub prompt" and add it - for subprompt, weight in weighted_subprompts: - log_tokenization(subprompt, model, log_tokens, weight) - subprompt_embeddings, _ = model.get_learned_conditioning([subprompt]) - c = torch.add( - c, - subprompt_embeddings, - alpha=weight, - ) - if edited_words is not None: - print("can't do cross-attention control with blends just yet, ignoring edits") - else: # just standard 1 prompt - log_tokenization(prompt, model, log_tokens, 1) - c, c_tokens = model.get_learned_conditioning([prompt]) - if edited_words is not None: - ec, ec_tokens = model.get_learned_conditioning([edited_words]) - edit_opcodes = build_token_edit_opcodes(c_tokens, ec_tokens) + conditioning = original_embeddings + edited_conditioning = edited_embeddings + edit_opcodes = build_token_edit_opcodes(original_tokens, edited_tokens) + else: + conditioning, _ = make_embeddings_for_flattened_prompt(model, flattened_prompt) - return (uc, c, ec, edit_opcodes) + unconditioning = make_embeddings_for_flattened_prompt(parsed_negative_prompt) + return (unconditioning, conditioning, edited_conditioning, edit_opcodes) -def build_token_edit_opcodes(c_tokens, ec_tokens): - tokens = c_tokens.cpu().numpy()[0] - tokens_edit = ec_tokens.cpu().numpy()[0] - opcodes = SequenceMatcher(None, tokens, tokens_edit).get_opcodes() - return opcodes +def build_token_edit_opcodes(original_tokens, edited_tokens): + original_tokens = original_tokens.cpu().numpy()[0] + edited_tokens = edited_tokens.cpu().numpy()[0] + + return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes() + +def make_embeddings_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt): + if type(flattened_prompt) is not FlattenedPrompt: + raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead" + fragments = [x[0] for x in flattened_prompt.children] + embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True) + return embeddings, tokens + def split_weighted_subprompts(text, skip_normalize=False)->list: diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py new file mode 100644 index 0000000000..9dd0f80ade --- /dev/null +++ b/ldm/invoke/prompt_parser.py @@ -0,0 +1,326 @@ +import pyparsing +import pyparsing as pp +from pyparsing import original_text_for + + +class Prompt(): + + def __init__(self, parts: list): + for c in parts: + if not issubclass(type(c), BaseFragment): + raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {BaseFragment.__subclasses__()} are allowed") + self.children = parts + def __repr__(self): + return f"Prompt:{self.children}" + def __eq__(self, other): + return type(other) is Prompt and other.children == self.children + +class FlattenedPrompt(): + def __init__(self, parts: list): + # verify type correctness + parts_converted = [] + for part in parts: + if issubclass(type(part), BaseFragment): + parts_converted.append(part) + elif type(part) is tuple: + # upgrade tuples to Fragments + if type(part[0]) is not str or (type(part[1]) is not float and type(part[1]) is not int): + raise PromptParser.ParsingException( + f"FlattenedPrompt cannot contain {part}, only Fragments or (str, float) tuples are allowed") + parts_converted.append(Fragment(part[0], part[1])) + else: + raise PromptParser.ParsingException( + f"FlattenedPrompt cannot contain {part}, only Fragments or (str, float) tuples are allowed") + # all looks good + self.children = parts_converted + + def 
__repr__(self): + return f"FlattenedPrompt:{self.children}" + def __eq__(self, other): + return type(other) is FlattenedPrompt and other.children == self.children + +# abstract base class for Fragments +class BaseFragment: + pass + +class Fragment(BaseFragment): + def __init__(self, text: str, weight: float=1): + assert(type(text) is str) + self.text = text + self.weight = float(weight) + + def __repr__(self): + return "Fragment:'"+self.text+"'@"+str(self.weight) + def __eq__(self, other): + return type(other) is Fragment \ + and other.text == self.text \ + and other.weight == self.weight + +class CrossAttentionControlledFragment(BaseFragment): + pass + +class CrossAttentionControlSubstitute(CrossAttentionControlledFragment): + def __init__(self, original: Fragment, edited: Fragment): + self.original = original + self.edited = edited + + def __repr__(self): + return f"CrossAttentionControlSubstitute:('{self.original}'->'{self.edited}')" + def __eq__(self, other): + return type(other) is CrossAttentionControlSubstitute \ + and other.original == self.original \ + and other.edited == self.edited + +class CrossAttentionControlAppend(CrossAttentionControlledFragment): + def __init__(self, fragment: Fragment): + self.fragment = fragment + def __repr__(self): + return "CrossAttentionControlAppend:",self.fragment + def __eq__(self, other): + return type(other) is CrossAttentionControlAppend \ + and other.fragment == self.fragment + + + +class Conjunction(): + def __init__(self, prompts: list, weights: list = None): + # force everything to be a Prompt + #print("making conjunction with", parts) + self.prompts = [x if (type(x) is Prompt + or type(x) is Blend + or type(x) is FlattenedPrompt) + else Prompt(x) for x in prompts] + self.weights = [1.0]*len(self.prompts) if weights is None else list(weights) + if len(self.weights) != len(self.prompts): + raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}") + self.type = 'AND' + + def __repr__(self): + return f"Conjunction:{self.prompts} | weights {self.weights}" + def __eq__(self, other): + return type(other) is Conjunction \ + and other.prompts == self.prompts \ + and other.weights == self.weights + + +class Blend(): + def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True): + #print("making Blend with prompts", prompts, "and weights", weights) + if len(prompts) != len(weights): + raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}") + for c in prompts: + if type(c) is not Prompt and type(c) is not FlattenedPrompt: + raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts")) + # upcast all lists to Prompt objects + self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt) + else Prompt(x) for x in prompts] + self.prompts = prompts + self.weights = weights + self.normalize_weights = normalize_weights + + def __repr__(self): + return f"Blend:{self.prompts} | weights {self.weights}" + def __eq__(self, other): + return other.__repr__() == self.__repr__() + + +class PromptParser(): + + class ParsingException(Exception): + pass + + def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9): + + self.attention_plus_base = attention_plus_base + self.attention_minus_base = attention_minus_base + + self.root = self.build_parser_logic() + + + def parse(self, prompt: str) -> [list]: + ''' + :param prompt: The prompt string to parse + 
:return: a tuple + ''' + #print(f"!!parsing '{prompt}'") + + if len(prompt.strip()) == 0: + return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0]) + + root = self.root.parse_string(prompt) + #print(f"'{prompt}' parsed to root", root) + #fused = fuse_fragments(parts) + #print("fused to", fused) + + return self.flatten(root[0]) + + def flatten(self, root: Conjunction): + + def fuse_fragments(items): + # print("fusing fragments in ", items) + result = [] + for x in items: + if issubclass(type(x), CrossAttentionControlledFragment): + result.append(x) + else: + last_weight = result[-1].weight \ + if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \ + else None + this_text = x.text + this_weight = x.weight + if last_weight is not None and last_weight == this_weight: + last_text = result[-1].text + result[-1] = Fragment(last_text + ' ' + this_text, last_weight) + else: + result.append(x) + return result + + def flatten_internal(node, weight_scale, results, prefix): + #print(prefix + "flattening", node, "...") + if type(node) is pp.ParseResults: + for x in node: + results = flatten_internal(x, weight_scale, results, prefix+'pr') + #print(prefix, " ParseResults expanded, results is now", results) + elif issubclass(type(node), BaseFragment): + results.append(node) + #elif type(node) is Attention: + # #if node.weight < 1: + # # todo: inject a blend when flattening attention with weight <1" + # for c in node.children: + # results = flatten_internal(c, weight_scale*node.weight, results, prefix+' ') + elif type(node) is Blend: + flattened_subprompts = [] + #print(" flattening blend with prompts", node.prompts, "weights", node.weights) + for prompt in node.prompts: + # prompt is a list + flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ') + results += [Blend(prompts=flattened_subprompts, weights=node.weights)] + elif type(node) is Prompt: + #print(prefix + "about to flatten Prompt with children", node.children) + flattened_prompt = [] + for child in node.children: + flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ') + results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))] + #print(prefix + "after flattening Prompt, results is", results) + else: + raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}") + #print(prefix + "-> after flattening", type(node), "results is", results) + return results + + #print("flattening", root) + + flattened_parts = [] + for part in root.prompts: + flattened_parts += flatten_internal(part, 1.0, [], ' C| ') + weights = root.weights + return Conjunction(flattened_parts, weights) + + + + def build_parser_logic(self): + + lparen = pp.Literal("(").suppress() + rparen = pp.Literal(")").suppress() + # accepts int or float notation, always maps to float + number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float)) + SPACE_CHARS = ' \t\n' + + prompt_part = pp.Forward() + word = pp.Word(pp.printables).set_parse_action(lambda x: Fragment(' '.join([s for s in x]))) + word.set_name("word") + word.set_debug(False) + + def make_fragment(x): + #print("### making fragment for", x) + if type(x) is str: + return Fragment(x) + elif type(x) is pp.ParseResults or type(x) is list: + return Fragment(' '.join([s for s in x])) + else: + raise PromptParser.ParsingException("Cannot make fragment from " + str(x)) + + + original_words = ( + (lparen + 
pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) | + (pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('term2').set_debug(False) | + (lparen + pp.CharsNotIn(')') + rparen).set_name('term3').set_debug(False) + ).set_name('original_words') + edited_words = ( + (pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('termA').set_debug(False) | + pp.CharsNotIn(')').set_name('termB').set_debug(False) + ).set_name('edited_words') + cross_attention_substitute = original_words + \ + pp.Literal(".swap").suppress() + \ + lparen + edited_words + rparen + cross_attention_substitute.set_name('cross_attention_substitute') + + def make_cross_attention_substitute(x): + #print("making cacs for", x) + return CrossAttentionControlSubstitute(x[0], x[1]) + #print("made", cacs) + #return cacs + + cross_attention_substitute.set_parse_action(make_cross_attention_substitute) + + # simple fragments of text + prompt_part << (cross_attention_substitute + #| attention + | word + ) + prompt_part.set_debug(False) + prompt_part.set_name("prompt_part") + + # root prompt definition + prompt = pp.Group(pp.OneOrMore(prompt_part))\ + .set_parse_action(lambda x: Prompt(x[0])) + + # weighted blend of prompts + # ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or + # int weights. + # can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c) + + def make_prompt_from_quoted_string(x): + #print(' got quoted prompt', x) + + x_unquoted = x[0][1:-1] + if len(x_unquoted.strip()) == 0: + # print(' b : just an empty string') + return Prompt([Fragment('')]) + # print(' b parsing ', c_unquoted) + x_parsed = prompt.parse_string(x_unquoted) + #print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed) + return x_parsed[0] + + quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string) + quoted_prompt.set_name('quoted_prompt') + + blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms') + blend_weights = pp.delimited_list(number).set_name('blend_weights') + blend = pp.Group(lparen + pp.Group(blend_terms) + rparen + + pp.Literal(".blend").suppress() + + lparen + pp.Group(blend_weights) + rparen).set_name('blend') + blend.set_debug(False) + + + blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1])) + + conjunction_terms = blend_terms.copy().set_name('conjunction_terms') + conjunction_weights = blend_weights.copy().set_name('conjunction_weights') + conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen + + pp.Literal(".and").suppress() + + lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction') + def make_conjunction(x): + parts_raw = x[0][0] + weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw) + parts = [part for part in parts_raw] + return Conjunction(parts, weights) + conjunction_with_parens_and_quotes.set_parse_action(make_conjunction) + + implicit_conjunction = pp.OneOrMore(blend | prompt) + implicit_conjunction.set_parse_action(lambda x: Conjunction(x)) + + conjunction = conjunction_with_parens_and_quotes | implicit_conjunction + conjunction.set_debug(False) + + # top-level is a conjunction of one or more blends or prompts + return conjunction diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index 4b62b5e393..57027b224c 100644 --- a/ldm/models/diffusion/ddpm.py +++ 
b/ldm/models/diffusion/ddpm.py @@ -820,21 +820,21 @@ class LatentDiffusion(DDPM): ) return self.scale_factor * z - def get_learned_conditioning(self, c): + def get_learned_conditioning(self, c, **kwargs): if self.cond_stage_forward is None: if hasattr(self.cond_stage_model, 'encode') and callable( self.cond_stage_model.encode ): c = self.cond_stage_model.encode( - c, embedding_manager=self.embedding_manager + c, embedding_manager=self.embedding_manager, **kwargs ) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() else: - c = self.cond_stage_model(c) + c = self.cond_stage_model(c, **kwargs) else: assert hasattr(self.cond_stage_model, self.cond_stage_forward) - c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c, **kwargs) return c def meshgrid(self, h, w): diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 12ef737134..8f4ad26119 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -1,3 +1,5 @@ +import math + import torch import torch.nn as nn from functools import partial @@ -449,11 +451,23 @@ class FrozenCLIPEmbedder(AbstractEncoder): tokens = batch_encoding['input_ids'].to(self.device) z = self.transformer(input_ids=tokens, **kwargs) - return z, tokens + if kwargs.get('return_tokens', False): + return z, tokens + else: + return z def encode(self, text, **kwargs): return self(text, **kwargs) +class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): + @classmethod + def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor: + per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device) + if normalize: + per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights) + reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,)) + return torch.sum(embeddings * reshaped_weights, dim=1) + class FrozenCLIPTextEmbedder(nn.Module): """ diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py new file mode 100644 index 0000000000..2ef56c47ae --- /dev/null +++ b/tests/test_prompt_parser.py @@ -0,0 +1,173 @@ +import unittest + +from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute + + +def parse_prompt(prompt_string): + pp = PromptParser() + #print(f"parsing '{prompt_string}'") + parse_result = pp.parse(prompt_string) + #print(f"-> parsed '{prompt_string}' to {parse_result}") + return parse_result + +class PromptParserTestCase(unittest.TestCase): + + def test_empty(self): + self.assertEqual(Conjunction([FlattenedPrompt([('', 1)])]), parse_prompt('')) + + def test_basic(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire (flames)', 1)])]), parse_prompt("fire (flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([("fire flames", 1)])]), parse_prompt("fire flames")) + self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames", 1)])]), parse_prompt("fire, flames")) + self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames , fire", 1)])]), parse_prompt("fire, flames , fire")) + + def test_attention(self): + self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.5)])]), parse_prompt("0.5(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('fire flames', 0.5)])]), parse_prompt("0.5(fire flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flames', 1.1)])]), 
parse_prompt("+(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.9)])]), parse_prompt("-(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1), ('flames', 0.5)])]), parse_prompt("fire 0.5(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(1.1, 2))])]), parse_prompt("++(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(0.9, 2))])]), parse_prompt("--(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))])]), parse_prompt("---(flowers) +++flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))])]), parse_prompt("---(flowers) +++flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames+', pow(1.1, 3))])]), + parse_prompt("---(flowers) +++flames+")) + self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1)])]), + parse_prompt("+(pretty flowers)")) + self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1), (', the flames are too hot', 1)])]), + parse_prompt("+(pretty flowers), the flames are too hot")) + + def test_no_parens_attention_runon(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("++fire flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("--fire flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("flowers ++fire flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("flowers --fire flames")) + + + def test_explicit_conjunction(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and(1,1)')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()')) + self.assertEqual( + Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("2.0(fire)", "-flames").and()')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]), + FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()')) + + def test_conjunction_weights(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[2.0,1.0]), parse_prompt('("fire", "flames").and(2,1)')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[1.0,2.0]), parse_prompt('("fire", "flames").and(1,2)')) + + with self.assertRaises(PromptParser.ParsingException): + parse_prompt('("fire", "flames").and(2)') + parse_prompt('("fire", "flames").and(2,1,2)') + + def test_complex_conjunction(self): + self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]), + parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle)\").and(0.5, 0.5)")) + + def test_badly_formed(self): + def make_untouched_prompt(prompt): 
+ return Conjunction([FlattenedPrompt([(prompt, 1.0)])]) + + def assert_if_prompt_string_not_untouched(prompt): + self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt)) + + assert_if_prompt_string_not_untouched('a test prompt') + assert_if_prompt_string_not_untouched('a badly (formed test prompt') + assert_if_prompt_string_not_untouched('a badly formed test+ prompt') + assert_if_prompt_string_not_untouched('a badly (formed test+ prompt') + assert_if_prompt_string_not_untouched('a badly (formed test+ )prompt') + assert_if_prompt_string_not_untouched('a badly (formed test+ )prompt') + assert_if_prompt_string_not_untouched('(((a badly (formed test+ )prompt') + assert_if_prompt_string_not_untouched('(a (ba)dly (f)ormed test+ prompt') + self.assertEqual(Conjunction([FlattenedPrompt([('(a (ba)dly (f)ormed test+', 1.0), ('prompt', 1.1)])]), + parse_prompt('(a (ba)dly (f)ormed test+ +prompt')) + self.assertEqual(Conjunction([Blend([FlattenedPrompt([('((a badly (formed test+', 1.0)])], weights=[1.0])]), + parse_prompt('("((a badly (formed test+ ").blend(1.0)')) + + def test_blend(self): + self.assertEqual(Conjunction( + [Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]), + parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)") + ) + self.assertEqual(Conjunction([Blend( + [FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('hi', 1.0)])], + [0.7, 0.3, 1.0])]), + parse_prompt("(\"fire\", \"fire flames\", \"hi\").blend(0.7, 0.3, 1.0)") + ) + self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), + FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]), + FlattenedPrompt([('hi', 1.0)])], + weights=[0.7, 0.3, 1.0])]), + parse_prompt("(\"fire\", \"fire flames ++(hot)\", \"hi\").blend(0.7, 0.3, 1.0)") + ) + # blend a single entry is not a failure + self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]), + parse_prompt("(\"fire\").blend(0.7)") + ) + # blend with empty + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \"\").blend(0.7, 1)") + ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \" \").blend(0.7, 1)") + ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \" \").blend(0.7, 1)") + ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([(',', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \" , \").blend(0.7, 1)") + ) + + + def test_nested(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)])]), + parse_prompt('fire 2.0(flames 1.5(trees))')) + self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]), + FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])], + weights=[1.0, 1.0])]), + parse_prompt('("fire ++(flames)", "mountain 2(man)").blend(1,1)')) + + def test_cross_attention_control(self): + fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \ + CrossAttentionControlSubstitute('flames', 'trees')])]) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)')) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)')) + 
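# (illustrative note, not part of the patch: the assertions around this point check that
#  the original fragment of a .swap() may be quoted, parenthesised, or both, and that the
#  edited fragment may be bare or quoted, with every combination parsing to the same
#  CrossAttentionControlSubstitute)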
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)')) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap("trees")')) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap("trees")')) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")')) + + fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \ + CrossAttentionControlSubstitute('flames', 'trees and houses')])]) + self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")')) + self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")')) + self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")')) + + trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \ + CrossAttentionControlSubstitute('trees and houses', 'flames')])]) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap(flames)')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap(flames)')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)')) + + flames_to_trees_fire = Conjunction([FlattenedPrompt([ + CrossAttentionControlSubstitute('flames', 'trees'), + (', fire', 1.0)])]) + self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire')) + self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire ')) + self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire ')) + self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire')) + self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire')) + self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire')) + + +if __name__ == '__main__': + unittest.main() From c9d27634b4b557b2bfd012f857ea2bb2f3a40e51 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Thu, 20 Oct 2022 12:01:48 +0200 Subject: [PATCH 22/54] bring in prompt parser from fix-prompts branch attention is parsed but ignored, blends old syntax doesn't work, conjunctions are parsed but ignored, the only part that's used here is the new .blend() syntax and cross-attention control using .swap() --- ldm/invoke/conditioning.py | 24 ++++--- ldm/invoke/generator/img2img.py | 4 +- ldm/invoke/generator/txt2img.py | 4 +- ldm/invoke/generator/txt2img2img.py | 6 +- ldm/invoke/prompt_parser.py | 108 ++++++++++++++++++++++------ ldm/models/diffusion/ddim.py | 6 +- ldm/models/diffusion/ksampler.py | 14 ++-- ldm/models/diffusion/plms.py | 6 +- ldm/modules/encoders/modules.py | 9 ++- tests/test_prompt_parser.py | 49 ++++++++++--- 10 files changed, 169 insertions(+), 61 deletions(-) diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index b7c8e55e66..fb6d8d443e 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -4,7 +4,7 @@ weighted subprompts. 
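# --- illustrative sketch, not part of the patch ---
# A minimal example of the two pieces of prompt syntax this commit actually honours,
# .swap() and .blend(), using prompt strings taken from tests/test_prompt_parser.py.
# Assumes ldm.invoke.prompt_parser is importable as added in the previous commit.
from ldm.invoke.prompt_parser import PromptParser

pp = PromptParser()

# cross-attention control: keep the overall composition of 'fire flames' but
# substitute 'trees' for 'flames' by re-using the saved attention maps
swap_result = pp.parse('fire "flames".swap(trees)')
# -> Conjunction of one FlattenedPrompt containing Fragment('fire', 1) and a
#    CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])

# weighted blend: encode each quoted prompt separately, then mix the embeddings 0.7 : 0.3
blend_result = pp.parse('("fire", "fire flames").blend(0.7, 0.3)')
# -> Conjunction of one Blend over two FlattenedPrompts with weights [0.7, 0.3]

print(swap_result)
print(blend_result)
# --- end sketch ---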
Useful function exports: -get_uc_and_c() get the conditioned and unconditioned latent +get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control split_weighted_subpromopts() split subprompts, normalize and weight them log_tokenization() print out colour-coded tokens and warn if truncated @@ -39,8 +39,10 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n pp = PromptParser() - parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned) - parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words) + # we don't support conjunctions for now + parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned).prompts[0] + parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words).prompts[0] + print("parsed prompt to", parsed_prompt) conditioning = None edited_conditioning = None @@ -50,7 +52,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n blend: Blend = parsed_prompt embeddings_to_blend = None for flattened_prompt in blend.prompts: - this_embedding = make_embeddings_for_flattened_prompt(model, flattened_prompt) + this_embedding = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt) embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat( (embeddings_to_blend, this_embedding)) conditioning, _ = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0), @@ -72,16 +74,16 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n # regular fragment original_prompt.append(fragment) edited_prompt.append(fragment) - original_embeddings, original_tokens = make_embeddings_for_flattened_prompt(model, original_prompt) - edited_embeddings, edited_tokens = make_embeddings_for_flattened_prompt(model, edited_prompt) + original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt) + edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt) conditioning = original_embeddings edited_conditioning = edited_embeddings edit_opcodes = build_token_edit_opcodes(original_tokens, edited_tokens) else: - conditioning, _ = make_embeddings_for_flattened_prompt(model, flattened_prompt) + conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt) - unconditioning = make_embeddings_for_flattened_prompt(parsed_negative_prompt) + unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt) return (unconditioning, conditioning, edited_conditioning, edit_opcodes) @@ -91,11 +93,11 @@ def build_token_edit_opcodes(original_tokens, edited_tokens): return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes() -def make_embeddings_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt): +def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt): if type(flattened_prompt) is not FlattenedPrompt: raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead" - fragments = [x[0] for x in flattened_prompt.children] - embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True) + fragments = [x.text for x in flattened_prompt.children] + embeddings, tokens = model.get_learned_conditioning([' '.join(fragments)], return_tokens=True) return embeddings, tokens diff --git 
a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index 0a12bd90e5..6fa0d0c6dd 100644 --- a/ldm/invoke/generator/img2img.py +++ b/ldm/invoke/generator/img2img.py @@ -34,7 +34,7 @@ class Img2Img(Generator): t_enc = int(strength * steps) uc, c, ec, edit_opcodes = conditioning - structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes) + extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes) def make_image(x_T): # encode (scaled latent) @@ -52,7 +52,7 @@ class Img2Img(Generator): unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, init_latent = self.init_latent, - structured_conditioning = structured_conditioning + extra_conditioning_info = extra_conditioning_info # changes how noising is performed in ksampler ) diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index 6e158562c5..657cccc592 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -22,7 +22,7 @@ class Txt2Img(Generator): """ self.perlin = perlin uc, c, ec, edit_opcodes = conditioning - structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes) + extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes) @torch.no_grad() def make_image(x_T): @@ -46,7 +46,7 @@ class Txt2Img(Generator): verbose = False, unconditional_guidance_scale = cfg_scale, unconditional_conditioning = uc, - structured_conditioning = structured_conditioning, + extra_conditioning_info = extra_conditioning_info, eta = ddim_eta, img_callback = step_callback, threshold = threshold, diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index 52a14aae74..64d0468418 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -24,7 +24,7 @@ class Txt2Img2Img(Generator): kwargs are 'width' and 'height' """ uc, c, ec, edit_opcodes = conditioning - structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes) + extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes) @torch.no_grad() def make_image(x_T): @@ -63,7 +63,7 @@ class Txt2Img2Img(Generator): unconditional_conditioning = uc, eta = ddim_eta, img_callback = step_callback, - structured_conditioning = structured_conditioning + extra_conditioning_info = extra_conditioning_info ) print( @@ -97,7 +97,7 @@ class Txt2Img2Img(Generator): img_callback = step_callback, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, - structured_conditioning = structured_conditioning + extra_conditioning_info = extra_conditioning_info ) if self.free_gpu_mem: diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 9dd0f80ade..c13175a488 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -7,7 +7,7 @@ class Prompt(): def __init__(self, parts: list): for c in parts: - if not issubclass(type(c), BaseFragment): + if type(c) is not Attention and not issubclass(type(c), BaseFragment): raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {BaseFragment.__subclasses__()} are allowed") self.children = parts def __repr__(self): @@ -56,6 +56,17 @@ class Fragment(BaseFragment): and other.text == self.text 
\ and other.weight == self.weight +class Attention(): + def __init__(self, weight: float, children: list): + self.weight = weight + self.children = children + #print(f"A: requested attention '{children}' to {weight}") + + def __repr__(self): + return f"Attention:'{self.children}' @ {self.weight}" + def __eq__(self, other): + return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment + class CrossAttentionControlledFragment(BaseFragment): pass @@ -65,7 +76,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment): self.edited = edited def __repr__(self): - return f"CrossAttentionControlSubstitute:('{self.original}'->'{self.edited}')" + return f"CrossAttentionControlSubstitute:({self.original}->{self.edited})" def __eq__(self, other): return type(other) is CrossAttentionControlSubstitute \ and other.original == self.original \ @@ -137,7 +148,7 @@ class PromptParser(): self.root = self.build_parser_logic() - def parse(self, prompt: str) -> [list]: + def parse(self, prompt: str) -> Conjunction: ''' :param prompt: The prompt string to parse :return: a tuple @@ -181,13 +192,17 @@ class PromptParser(): for x in node: results = flatten_internal(x, weight_scale, results, prefix+'pr') #print(prefix, " ParseResults expanded, results is now", results) - elif issubclass(type(node), BaseFragment): - results.append(node) - #elif type(node) is Attention: - # #if node.weight < 1: - # # todo: inject a blend when flattening attention with weight <1" - # for c in node.children: - # results = flatten_internal(c, weight_scale*node.weight, results, prefix+' ') + elif type(node) is Attention: + # if node.weight < 1: + # todo: inject a blend when flattening attention with weight <1" + for c in node.children: + results = flatten_internal(c, weight_scale * node.weight, results, prefix + ' ') + elif type(node) is Fragment: + results += [Fragment(node.text, node.weight*weight_scale)] + elif type(node) is CrossAttentionControlSubstitute: + original = flatten_internal(node.original, weight_scale, [], ' CAo ') + edited = flatten_internal(node.edited, weight_scale, [], ' CAe ') + results += [CrossAttentionControlSubstitute(original, edited)] elif type(node) is Blend: flattened_subprompts = [] #print(" flattening blend with prompts", node.prompts, "weights", node.weights) @@ -204,7 +219,7 @@ class PromptParser(): #print(prefix + "after flattening Prompt, results is", results) else: raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}") - #print(prefix + "-> after flattening", type(node), "results is", results) + print(prefix + "-> after flattening", type(node), "results is", results) return results #print("flattening", root) @@ -239,32 +254,83 @@ class PromptParser(): else: raise PromptParser.ParsingException("Cannot make fragment from " + str(x)) + # attention control of the form +(phrase) / -(phrase) / (phrase) + # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight + attention = pp.Forward() + attention_head = (number | pp.Word('+') | pp.Word('-'))\ + .set_name("attention_head")\ + .set_debug(False) + fragment_inside_attention = pp.CharsNotIn(SPACE_CHARS+'()')\ + .set_parse_action(make_fragment)\ + .set_name("fragment_inside_attention")\ + .set_debug(False) + attention_with_parens = pp.Forward() + attention_with_parens_body = pp.nested_expr(content=pp.delimited_list((attention_with_parens | fragment_inside_attention), delim=SPACE_CHARS)) + 
attention_with_parens << (attention_head + attention_with_parens_body) + + def make_attention(x): + # print("making Attention from parsing with args", x0, x1) + weight = 1 + # number(str) + if type(x[0]) is float or type(x[0]) is int: + weight = float(x[0]) + # +(str) or -(str) or +str or -str + elif type(x[0]) is str: + base = self.attention_plus_base if x[0][0] == '+' else self.attention_minus_base + weight = pow(base, len(x[0])) + # print("Making attention with children of type", [str(type(x)) for x in x1]) + return Attention(weight=weight, children=x[1]) + + attention_with_parens.set_parse_action(make_attention)\ + .set_name("attention_with_parens")\ + .set_debug(False) + + # attention control of the form ++word --word (no parens) + attention_without_parens = ( + (pp.Word('+') | pp.Word('-')) + + pp.CharsNotIn(SPACE_CHARS+'()').set_parse_action(lambda x: [[make_fragment(x)]]) + )\ + .set_name("attention_without_parens")\ + .set_debug(False) + attention_without_parens.set_parse_action(make_attention) + + attention << (attention_with_parens | attention_without_parens)\ + .set_name("attention")\ + .set_debug(False) + + # cross-attention control + empty_string = ((lparen + rparen) | + pp.Literal('""').suppress() | + (lparen + pp.Literal('""').suppress() + rparen) + ).set_parse_action(lambda x: Fragment("")) original_words = ( - (lparen + pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) | - (pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('term2').set_debug(False) | - (lparen + pp.CharsNotIn(')') + rparen).set_name('term3').set_debug(False) + (lparen + pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) | + (pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('term2').set_debug(False) | + (lparen + (pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)) + rparen).set_name('term3').set_debug(False) ).set_name('original_words') edited_words = ( - (pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('termA').set_debug(False) | - pp.CharsNotIn(')').set_name('termB').set_debug(False) + (pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('termA').set_debug(False) | + pp.Literal('""').suppress().set_parse_action(lambda x: Fragment("")) | + (pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)).set_name('termB').set_debug(True) ).set_name('edited_words') - cross_attention_substitute = original_words + \ + cross_attention_substitute = (empty_string | original_words) + \ pp.Literal(".swap").suppress() + \ - lparen + edited_words + rparen + (empty_string | (lparen + edited_words + rparen) + ) cross_attention_substitute.set_name('cross_attention_substitute') def make_cross_attention_substitute(x): #print("making cacs for", x) - return CrossAttentionControlSubstitute(x[0], x[1]) + cacs = CrossAttentionControlSubstitute(x[0], x[1]) #print("made", cacs) - #return cacs + return cacs cross_attention_substitute.set_parse_action(make_cross_attention_substitute) # simple fragments of text prompt_part << (cross_attention_substitute - #| attention + | attention | word ) prompt_part.set_debug(False) diff 
--git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index 0ab6911247..98219fb62e 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -16,10 +16,10 @@ class DDIMSampler(Sampler): def prepare_to_sample(self, t_enc, **kwargs): super().prepare_to_sample(t_enc, **kwargs) - structured_conditioning = kwargs.get('structured_conditioning', None) + extra_conditioning_info = kwargs.get('extra_conditioning_info', None) - if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning) + if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: + self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info) else: self.invokeai_diffuser.remove_cross_attention_control() diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index a8291e32c1..8c858757eb 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -34,10 +34,10 @@ class CFGDenoiser(nn.Module): def prepare_to_sample(self, t_enc, **kwargs): - structured_conditioning = kwargs.get('structured_conditioning', None) + extra_conditioning_info = kwargs.get('extra_conditioning_info', None) - if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning) + if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: + self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info) else: self.invokeai_diffuser.remove_cross_attention_control() @@ -164,7 +164,7 @@ class KSampler(Sampler): log_every_t=100, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - structured_conditioning=None, + extra_conditioning_info=None, threshold = 0, perlin = 0, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
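# --- illustrative sketch, not part of the patch ---
# How the renamed keyword threads through the stack, pieced together from the generator
# and sampler hunks in this commit; the call sites are simplified and partly assumed.
#
# in a generator (txt2img / img2img / txt2img2img):
#   uc, c, ec, edit_opcodes = conditioning
#   extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(
#       edited_conditioning=ec, edit_opcodes=edit_opcodes)
#   sampler.sample(..., unconditional_conditioning=uc,
#                  extra_conditioning_info=extra_conditioning_info)
#
# in each sampler's prepare_to_sample():
#   info = kwargs.get('extra_conditioning_info', None)
#   if info is not None and info.wants_cross_attention_control:
#       self.invokeai_diffuser.setup_cross_attention_control(info)
#   else:
#       self.invokeai_diffuser.remove_cross_attention_control()
# --- end sketch ---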
@@ -197,7 +197,7 @@ class KSampler(Sampler): x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10)) - model_wrap_cfg.prepare_to_sample(S, structured_conditioning=structured_conditioning) + model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info) extra_args = { 'cond': conditioning, 'uncond': unconditional_conditioning, @@ -224,7 +224,7 @@ class KSampler(Sampler): index, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - structured_conditioning=None, + extra_conditioning_info=None, **kwargs, ): if self.model_wrap is None: @@ -250,7 +250,7 @@ class KSampler(Sampler): # so the actual formula for indexing into sigmas: # sigma_index = (steps-index) s_index = t_enc - index - 1 - self.model_wrap.prepare_to_sample(s_index, structured_conditioning=structured_conditioning) + self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info) img = K.sampling.__dict__[f'_{self.schedule}']( self.model_wrap, img, diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 98975525ed..f58e2c3220 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -20,10 +20,10 @@ class PLMSSampler(Sampler): def prepare_to_sample(self, t_enc, **kwargs): super().prepare_to_sample(t_enc, **kwargs) - structured_conditioning = kwargs.get('structured_conditioning', None) + extra_conditioning_info = kwargs.get('extra_conditioning_info', None) - if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning) + if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: + self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info) else: self.invokeai_diffuser.remove_cross_attention_control() diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 8f4ad26119..18878af443 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -439,6 +439,13 @@ class FrozenCLIPEmbedder(AbstractEncoder): param.requires_grad = False def forward(self, text, **kwargs): + + should_return_tokens = False + if 'return_tokens' in kwargs: + should_return_tokens = kwargs.get('return_tokens', False) + # self.transformer doesn't like having extra kwargs + kwargs.pop('return_tokens') + batch_encoding = self.tokenizer( text, truncation=True, @@ -451,7 +458,7 @@ class FrozenCLIPEmbedder(AbstractEncoder): tokens = batch_encoding['input_ids'].to(self.device) z = self.transformer(input_ids=tokens, **kwargs) - if kwargs.get('return_tokens', False): + if should_return_tokens: return z, tokens else: return z diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 2ef56c47ae..99f4db33a1 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -1,6 +1,7 @@ import unittest -from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute +from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute, \ + Fragment def parse_prompt(prompt_string): @@ -135,7 +136,7 @@ class PromptParserTestCase(unittest.TestCase): def test_cross_attention_control(self): fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \ - CrossAttentionControlSubstitute('flames', 'trees')])]) + 
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])]) self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)')) self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)')) self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)')) @@ -144,13 +145,13 @@ class PromptParserTestCase(unittest.TestCase): self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")')) fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \ - CrossAttentionControlSubstitute('flames', 'trees and houses')])]) + CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees and houses', 1)])])]) self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")')) self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")')) self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")')) trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \ - CrossAttentionControlSubstitute('trees and houses', 'flames')])]) + CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])]) self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")')) self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")')) self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")')) @@ -159,14 +160,46 @@ class PromptParserTestCase(unittest.TestCase): self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)')) flames_to_trees_fire = Conjunction([FlattenedPrompt([ - CrossAttentionControlSubstitute('flames', 'trees'), + CrossAttentionControlSubstitute([Fragment('flames',1)], [Fragment('trees',1)]), (', fire', 1.0)])]) - self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire')) - self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire ')) - self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire ')) self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire')) self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire')) self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire')) + self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire')) + self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire ')) + self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire ')) + + + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), + CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]), + parse_prompt('a forest landscape "".swap("in winter")')) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), + CrossAttentionControlSubstitute([Fragment(' ',1)], [Fragment('in winter',1)])])]), + parse_prompt('a forest landscape " ".swap("in winter")')) + + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), + CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]), + parse_prompt('a forest landscape "in winter".swap("")')) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), + 
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]), + parse_prompt('a forest landscape "in winter".swap()')) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), + CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment(' ',1)])])]), + parse_prompt('a forest landscape "in winter".swap(" ")')) + + def test_cross_attention_control_with_attention(self): + flames_to_trees_fire = Conjunction([FlattenedPrompt([ + CrossAttentionControlSubstitute([Fragment('flames',0.5)], [Fragment('trees',0.7)]), + Fragment(',', 1), Fragment('fire', 2.0)])]) + self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(flames)".swap("0.7(trees)"), 2.0(fire)')) + flames_to_trees_fire = Conjunction([FlattenedPrompt([ + CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7)]), + Fragment(',', 1), Fragment('fire', 2.0)])]) + self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees)"), 2.0(fire)')) + flames_to_trees_fire = Conjunction([FlattenedPrompt([ + CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7), Fragment('houses', 1)]), + Fragment(',', 1), Fragment('fire', 2.0)])]) + self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)')) if __name__ == '__main__': From da223dfe819efe96b3be99c158c08c96c619337a Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Thu, 20 Oct 2022 15:56:46 +0200 Subject: [PATCH 23/54] wip re-writing parts of prompt parser --- ldm/invoke/prompt_parser.py | 80 +++++++++++++++++++++++-------------- tests/test_prompt_parser.py | 3 ++ 2 files changed, 52 insertions(+), 31 deletions(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index c13175a488..abd9ce726c 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -1,3 +1,5 @@ +import string + import pyparsing import pyparsing as pp from pyparsing import original_text_for @@ -200,8 +202,8 @@ class PromptParser(): elif type(node) is Fragment: results += [Fragment(node.text, node.weight*weight_scale)] elif type(node) is CrossAttentionControlSubstitute: - original = flatten_internal(node.original, weight_scale, [], ' CAo ') - edited = flatten_internal(node.edited, weight_scale, [], ' CAe ') + original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ') + edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ') results += [CrossAttentionControlSubstitute(original, edited)] elif type(node) is Blend: flattened_subprompts = [] @@ -236,24 +238,46 @@ class PromptParser(): lparen = pp.Literal("(").suppress() rparen = pp.Literal(")").suppress() + quotes = pp.Literal('"').suppress() + # accepts int or float notation, always maps to float number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float)) - SPACE_CHARS = ' \t\n' - - prompt_part = pp.Forward() - word = pp.Word(pp.printables).set_parse_action(lambda x: Fragment(' '.join([s for s in x]))) - word.set_name("word") - word.set_debug(False) + SPACE_CHARS = string.whitespace def make_fragment(x): #print("### making fragment for", x) if type(x) is str: return Fragment(x) elif type(x) is pp.ParseResults or type(x) is list: + #print(f'converting {x} to Fragment') return Fragment(' '.join([s for s in x])) else: raise PromptParser.ParsingException("Cannot make fragment from " + str(x)) + def parse_fragment_str(x): + return 
make_fragment(x) + + quoted_fragment = pp.QuotedString(quote_char='"', esc_char='\\') + quoted_fragment.set_parse_action(parse_fragment_str).set_name('quoted_fragment') + + unquoted_fragment = pp.Combine(pp.OneOrMore( + pp.Literal('\\"').set_debug(False) | + pp.Literal('\\').set_debug(False) | + pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"') + )) + unquoted_fragment.set_parse_action(parse_fragment_str).set_name('unquoted_fragment') + + parenthesized_fragment = \ + (lparen + quoted_fragment.set_debug(True) + rparen).set_name('quoted_paren_internal') | \ + (lparen + rparen).set_parse_action(lambda x: make_fragment('')) | \ + (lparen + pp.Combine(pp.OneOrMore( + pp.Literal('\\)').set_debug(False) | + pp.Literal('\\').set_debug(False) | + pp.Word(pp.printables, exclude_chars=string.whitespace + '\\)') | + pp.Word(string.whitespace) + )) + rparen).set_parse_action(parse_fragment_str).set_name('unquoted_paren_internal').set_debug(True) + parenthesized_fragment.set_name('parenthesized_fragment').set_debug(True) + # attention control of the form +(phrase) / -(phrase) / (phrase) # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight attention = pp.Forward() @@ -303,41 +327,35 @@ class PromptParser(): pp.Literal('""').suppress() | (lparen + pp.Literal('""').suppress() + rparen) ).set_parse_action(lambda x: Fragment("")) + empty_string.set_name('empty_string') - original_words = ( - (lparen + pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) | - (pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('term2').set_debug(False) | - (lparen + (pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)) + rparen).set_name('term3').set_debug(False) - ).set_name('original_words') - edited_words = ( - (pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('termA').set_debug(False) | - pp.Literal('""').suppress().set_parse_action(lambda x: Fragment("")) | - (pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)).set_name('termB').set_debug(True) - ).set_name('edited_words') - cross_attention_substitute = (empty_string | original_words) + \ - pp.Literal(".swap").suppress() + \ - (empty_string | (lparen + edited_words + rparen) - ) - cross_attention_substitute.set_name('cross_attention_substitute') + original_fragment = empty_string | quoted_fragment | parenthesized_fragment | unquoted_fragment + edited_fragment = parenthesized_fragment + cross_attention_substitute = original_fragment + pp.Literal(".swap").suppress() + edited_fragment + + cross_attention_substitute.set_name('cross_attention_substitute').set_debug(True) def make_cross_attention_substitute(x): - #print("making cacs for", x) + print("making cacs for", x) cacs = CrossAttentionControlSubstitute(x[0], x[1]) - #print("made", cacs) + print("made", cacs) return cacs - cross_attention_substitute.set_parse_action(make_cross_attention_substitute) # simple fragments of text - prompt_part << (cross_attention_substitute - | attention - | word - ) + prompt_part = ( + cross_attention_substitute + | attention + | quoted_fragment + | unquoted_fragment + ) prompt_part.set_debug(False) prompt_part.set_name("prompt_part") + 
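At this stage of the rewrite, the original term of a `.swap` may be quoted, parenthesized, bare, or empty, while the edited term must be parenthesized (possibly empty). A hedged sketch of inputs the grammar is meant to accept, using strings taken from the tests in this patch series; the printed repr is indicative only:

    # Hedged sketch: exercising .swap forms the rewritten grammar should accept.
    from ldm.invoke.prompt_parser import PromptParser

    parser = PromptParser()
    for prompt in (
        'fire "flames".swap(trees)',                # quoted original, parenthesized edit
        'fire (flames).swap("trees")',              # parenthesized original, quoted edit
        'a forest landscape "".swap("in winter")',  # empty original: pure insertion
        'a forest landscape "in winter".swap()',    # empty edit: pure removal
    ):
        print(parser.parse(prompt))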
empty = ((lparen + rparen) | (quotes + quotes)).suppress() + # root prompt definition - prompt = pp.Group(pp.OneOrMore(prompt_part))\ + prompt = (pp.Group(pp.OneOrMore(prompt_part) | empty) + pp.StringEnd()) \ .set_parse_action(lambda x: Prompt(x[0])) # weighted blend of prompts diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 99f4db33a1..0aa0cfd6ae 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -201,6 +201,9 @@ class PromptParserTestCase(unittest.TestCase): Fragment(',', 1), Fragment('fire', 2.0)])]) self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)')) + def test_single(self): + print(parse_prompt('fire (trees and houses).swap("flames")')) + if __name__ == '__main__': unittest.main() From 79b4afeae7f0e2e4f922de26e6d9a458fea5c46b Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Thu, 20 Oct 2022 16:56:34 +0200 Subject: [PATCH 24/54] parser working with basic escapes --- ldm/invoke/prompt_parser.py | 48 +++++++++++++++++++++++++------------ tests/test_prompt_parser.py | 21 ++++++++++++---- 2 files changed, 50 insertions(+), 19 deletions(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index abd9ce726c..398a596c7e 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -169,12 +169,16 @@ class PromptParser(): def flatten(self, root: Conjunction): + print("flattening", root) + def fuse_fragments(items): # print("fusing fragments in ", items) result = [] for x in items: - if issubclass(type(x), CrossAttentionControlledFragment): - result.append(x) + if type(x) is CrossAttentionControlSubstitute: + original_fused = fuse_fragments(x.original) + edited_fused = fuse_fragments(x.edited) + result.append(CrossAttentionControlSubstitute(original_fused, edited_fused)) else: last_weight = result[-1].weight \ if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \ @@ -221,10 +225,9 @@ class PromptParser(): #print(prefix + "after flattening Prompt, results is", results) else: raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}") - print(prefix + "-> after flattening", type(node), "results is", results) + print(prefix + "-> after flattening", type(node).__name__, "results is", results) return results - #print("flattening", root) flattened_parts = [] for part in root.prompts: @@ -244,6 +247,8 @@ class PromptParser(): number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float)) SPACE_CHARS = string.whitespace + attention = pp.Forward() + def make_fragment(x): #print("### making fragment for", x) if type(x) is str: @@ -254,33 +259,44 @@ class PromptParser(): else: raise PromptParser.ParsingException("Cannot make fragment from " + str(x)) + unquoted_fragment = pp.Forward() + quoted_fragment = pp.Forward() + parenthesized_fragment = pp.Forward() + def parse_fragment_str(x): - return make_fragment(x) + print("parsing", x) + if len(x[0].strip()) == 0: + return Fragment('') + fragment_parser = pp.Group(pp.OneOrMore(attention | pp.Word(pp.printables, exclude_chars=string.whitespace).set_parse_action(make_fragment))) + fragment_parser.set_name('word_or_attention') + result = fragment_parser.parse_string(x[0]) + #result = (pp.OneOrMore(attention | unquoted_fragment) + pp.StringEnd()).parse_string(x[0]) + print("parsed to", result) + return result - quoted_fragment = pp.QuotedString(quote_char='"', esc_char='\\') - 
quoted_fragment.set_parse_action(parse_fragment_str).set_name('quoted_fragment') + quoted_fragment << pp.QuotedString(quote_char='"', esc_char='\\') + quoted_fragment.set_parse_action(make_fragment).set_name('quoted_fragment') - unquoted_fragment = pp.Combine(pp.OneOrMore( + unquoted_fragment << pp.Combine(pp.OneOrMore( pp.Literal('\\"').set_debug(False) | pp.Literal('\\').set_debug(False) | pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"') )) - unquoted_fragment.set_parse_action(parse_fragment_str).set_name('unquoted_fragment') + unquoted_fragment.set_parse_action(make_fragment).set_name('unquoted_fragment') - parenthesized_fragment = \ - (lparen + quoted_fragment.set_debug(True) + rparen).set_name('quoted_paren_internal') | \ - (lparen + rparen).set_parse_action(lambda x: make_fragment('')) | \ + parenthesized_fragment << pp.Or([ + (lparen + quoted_fragment.set_parse_action(parse_fragment_str).set_debug(True) + rparen).set_name('-quoted_paren_internal').set_debug(True), + (lparen + rparen).set_parse_action(lambda x: make_fragment('')).set_name('-()').set_debug(True), (lparen + pp.Combine(pp.OneOrMore( pp.Literal('\\)').set_debug(False) | pp.Literal('\\').set_debug(False) | pp.Word(pp.printables, exclude_chars=string.whitespace + '\\)') | pp.Word(string.whitespace) - )) + rparen).set_parse_action(parse_fragment_str).set_name('unquoted_paren_internal').set_debug(True) + )).set_name('--combined').set_parse_action(parse_fragment_str).set_debug(True) + rparen)]).set_name('-unquoted_paren_internal').set_debug(True) parenthesized_fragment.set_name('parenthesized_fragment').set_debug(True) # attention control of the form +(phrase) / -(phrase) / (phrase) # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight - attention = pp.Forward() attention_head = (number | pp.Word('+') | pp.Word('-'))\ .set_name("attention_head")\ .set_debug(False) @@ -352,7 +368,9 @@ class PromptParser(): prompt_part.set_debug(False) prompt_part.set_name("prompt_part") - empty = ((lparen + rparen) | (quotes + quotes)).suppress() + empty = ( + (lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) | + (quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty') # root prompt definition prompt = (pp.Group(pp.OneOrMore(prompt_part) | empty) + pp.StringEnd()) \ diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 0aa0cfd6ae..38a24ca529 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -174,7 +174,7 @@ class PromptParserTestCase(unittest.TestCase): CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]), parse_prompt('a forest landscape "".swap("in winter")')) self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), - CrossAttentionControlSubstitute([Fragment(' ',1)], [Fragment('in winter',1)])])]), + CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]), parse_prompt('a forest landscape " ".swap("in winter")')) self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), @@ -184,7 +184,7 @@ class PromptParserTestCase(unittest.TestCase): CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]), parse_prompt('a forest landscape "in winter".swap()')) self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), - CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment(' ',1)])])]), + 
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]), parse_prompt('a forest landscape "in winter".swap(" ")')) def test_cross_attention_control_with_attention(self): @@ -201,8 +201,21 @@ class PromptParserTestCase(unittest.TestCase): Fragment(',', 1), Fragment('fire', 2.0)])]) self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)')) - def test_single(self): - print(parse_prompt('fire (trees and houses).swap("flames")')) + + def make_basic_conjunction(self, strings: list[str]): + fragments = [Fragment(x) for x in strings] + return Conjunction([FlattenedPrompt(fragments)]) + + def make_weighted_conjunction(self, weighted_strings: list[tuple[str,float]]): + fragments = [Fragment(x, w) for x,w in weighted_strings] + return Conjunction([FlattenedPrompt(fragments)]) + + + def test_escaping(self): + self.assertEqual(self.make_basic_conjunction(['mountain \(man\)']),parse_prompt('mountain \(man\)')) + self.assertEqual(self.make_basic_conjunction(['mountain (\(man)\)']),parse_prompt('mountain (\(man)\)')) + self.assertEqual(self.make_basic_conjunction(['mountain (\(man\))']),parse_prompt('mountain (\(man\))')) + #self.assertEqual(self.make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain +(\(man\))')) if __name__ == '__main__': From 3f13dd3ae8bbdf06067adf891e5caa925fde394a Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Thu, 20 Oct 2022 21:05:36 +0200 Subject: [PATCH 25/54] prompt parsing is now much more robust --- ldm/invoke/prompt_parser.py | 151 ++++++++++++++++++------------- tests/test_prompt_parser.py | 173 +++++++++++++++++++++++++++--------- 2 files changed, 220 insertions(+), 104 deletions(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 398a596c7e..d576d069aa 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -9,8 +9,8 @@ class Prompt(): def __init__(self, parts: list): for c in parts: - if type(c) is not Attention and not issubclass(type(c), BaseFragment): - raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {BaseFragment.__subclasses__()} are allowed") + if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults: + raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} {c}, only {BaseFragment.__subclasses__()} are allowed") self.children = parts def __repr__(self): return f"Prompt:{self.children}" @@ -48,6 +48,9 @@ class BaseFragment: class Fragment(BaseFragment): def __init__(self, text: str, weight: float=1): assert(type(text) is str) + if '\\"' in text or '\\(' in text or '\\)' in text: + #print("Fragment converting escaped \( \) \\\" into ( ) \"") + text = text.replace('\\(', '(').replace('\\)', ')').replace('\\"', '"') self.text = text self.weight = float(weight) @@ -152,8 +155,10 @@ class PromptParser(): def parse(self, prompt: str) -> Conjunction: ''' + This parser is *very* forgiving. If it cannot parse syntax, it will return strings as-is to be passed on to the + diffusion. :param prompt: The prompt string to parse - :return: a tuple + :return: a Conjunction representing the parsed results. 
''' #print(f"!!parsing '{prompt}'") @@ -169,7 +174,7 @@ class PromptParser(): def flatten(self, root: Conjunction): - print("flattening", root) + #print("flattening", root) def fuse_fragments(items): # print("fusing fragments in ", items) @@ -196,13 +201,13 @@ class PromptParser(): #print(prefix + "flattening", node, "...") if type(node) is pp.ParseResults: for x in node: - results = flatten_internal(x, weight_scale, results, prefix+'pr') + results = flatten_internal(x, weight_scale, results, prefix+' pr ') #print(prefix, " ParseResults expanded, results is now", results) elif type(node) is Attention: # if node.weight < 1: # todo: inject a blend when flattening attention with weight <1" - for c in node.children: - results = flatten_internal(c, weight_scale * node.weight, results, prefix + ' ') + for index,c in enumerate(node.children): + results = flatten_internal(c, weight_scale * node.weight, results, prefix + f" att{index} ") elif type(node) is Fragment: results += [Fragment(node.text, node.weight*weight_scale)] elif type(node) is CrossAttentionControlSubstitute: @@ -225,7 +230,7 @@ class PromptParser(): #print(prefix + "after flattening Prompt, results is", results) else: raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}") - print(prefix + "-> after flattening", type(node).__name__, "results is", results) + #print(prefix + "-> after flattening", type(node).__name__, "results is", results) return results @@ -246,6 +251,7 @@ class PromptParser(): # accepts int or float notation, always maps to float number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float)) SPACE_CHARS = string.whitespace + greedy_word = pp.Word(pp.printables, exclude_chars=string.whitespace).set_name('greedy_word') attention = pp.Forward() @@ -254,7 +260,7 @@ class PromptParser(): if type(x) is str: return Fragment(x) elif type(x) is pp.ParseResults or type(x) is list: - #print(f'converting {x} to Fragment') + #print(f'converting {type(x).__name__} to Fragment') return Fragment(' '.join([s for s in x])) else: raise PromptParser.ParsingException("Cannot make fragment from " + str(x)) @@ -264,52 +270,72 @@ class PromptParser(): parenthesized_fragment = pp.Forward() def parse_fragment_str(x): - print("parsing", x) + #print("parsing fragment string", x) if len(x[0].strip()) == 0: return Fragment('') - fragment_parser = pp.Group(pp.OneOrMore(attention | pp.Word(pp.printables, exclude_chars=string.whitespace).set_parse_action(make_fragment))) + fragment_parser = pp.Group(pp.OneOrMore(attention | (greedy_word.set_parse_action(make_fragment)))) fragment_parser.set_name('word_or_attention') result = fragment_parser.parse_string(x[0]) #result = (pp.OneOrMore(attention | unquoted_fragment) + pp.StringEnd()).parse_string(x[0]) - print("parsed to", result) + #print("parsed to", result) return result quoted_fragment << pp.QuotedString(quote_char='"', esc_char='\\') - quoted_fragment.set_parse_action(make_fragment).set_name('quoted_fragment') + quoted_fragment.set_parse_action(parse_fragment_str).set_name('quoted_fragment') + + self_unescaping_escaped_quote = pp.Literal('\\"').set_parse_action(lambda x: '"') + self_unescaping_escaped_lparen = pp.Literal('\\(').set_parse_action(lambda x: '(') + self_unescaping_escaped_rparen = pp.Literal('\\)').set_parse_action(lambda x: ')') unquoted_fragment << pp.Combine(pp.OneOrMore( - pp.Literal('\\"').set_debug(False) | - pp.Literal('\\').set_debug(False) | - pp.Word(pp.printables, exclude_chars=string.whitespace + 
'\\"') + self_unescaping_escaped_rparen | self_unescaping_escaped_lparen | self_unescaping_escaped_quote | + pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') )) unquoted_fragment.set_parse_action(make_fragment).set_name('unquoted_fragment') - parenthesized_fragment << pp.Or([ - (lparen + quoted_fragment.set_parse_action(parse_fragment_str).set_debug(True) + rparen).set_name('-quoted_paren_internal').set_debug(True), - (lparen + rparen).set_parse_action(lambda x: make_fragment('')).set_name('-()').set_debug(True), + parenthesized_fragment << pp.MatchFirst([ + (lparen + quoted_fragment.copy().set_parse_action(parse_fragment_str).set_debug(False) + rparen).set_name('-quoted_paren_internal').set_debug(False), + (lparen + rparen).set_parse_action(lambda x: make_fragment('')).set_name('-()').set_debug(False), (lparen + pp.Combine(pp.OneOrMore( - pp.Literal('\\)').set_debug(False) | - pp.Literal('\\').set_debug(False) | - pp.Word(pp.printables, exclude_chars=string.whitespace + '\\)') | + pp.Literal('\\"').set_debug(False).set_parse_action(lambda x: '"') | + pp.Literal('\\(').set_debug(False).set_parse_action(lambda x: '(') | + pp.Literal('\\)').set_debug(False).set_parse_action(lambda x: ')') | + pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') | pp.Word(string.whitespace) - )).set_name('--combined').set_parse_action(parse_fragment_str).set_debug(True) + rparen)]).set_name('-unquoted_paren_internal').set_debug(True) - parenthesized_fragment.set_name('parenthesized_fragment').set_debug(True) + )).set_name('--combined').set_parse_action(parse_fragment_str).set_debug(False) + rparen)]).set_name('-unquoted_paren_internal').set_debug(False) + parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False) + debug_attention = False # attention control of the form +(phrase) / -(phrase) / (phrase) # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight attention_head = (number | pp.Word('+') | pp.Word('-'))\ .set_name("attention_head")\ .set_debug(False) - fragment_inside_attention = pp.CharsNotIn(SPACE_CHARS+'()')\ - .set_parse_action(make_fragment)\ - .set_name("fragment_inside_attention")\ - .set_debug(False) + word_inside_attention = pp.Combine(pp.OneOrMore( + pp.Literal('\\)') | pp.Literal('\\(') | pp.Literal('\\"') | + pp.Word(pp.printables, exclude_chars=string.whitespace + '\\()"') + )).set_name('word_inside_attention') attention_with_parens = pp.Forward() - attention_with_parens_body = pp.nested_expr(content=pp.delimited_list((attention_with_parens | fragment_inside_attention), delim=SPACE_CHARS)) + attention_with_parens_delimited_list = pp.delimited_list(pp.Or([ + quoted_fragment.copy().set_debug(debug_attention), + attention.copy().set_debug(debug_attention), + word_inside_attention.set_debug(debug_attention)]).set_name('delim_inner').set_debug(debug_attention), + delim=string.whitespace) + # have to disable ignore_expr here to prevent pyparsing from stripping off quote marks + attention_with_parens_body = pp.nested_expr(content=attention_with_parens_delimited_list, + ignore_expr=None#((pp.Literal("\\(") | pp.Literal('\\)'))) + ) + attention_with_parens_body.set_debug(debug_attention) attention_with_parens << (attention_head + attention_with_parens_body) + attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention) + + attention_without_parens = (pp.Word('+') | pp.Word('-')) + (quoted_fragment | word_inside_attention) + 
attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention) + + attention << (attention_with_parens | attention_without_parens) def make_attention(x): - # print("making Attention from parsing with args", x0, x1) + #print("making Attention from", x) weight = 1 # number(str) if type(x[0]) is float or type(x[0]) is int: @@ -318,26 +344,17 @@ class PromptParser(): elif type(x[0]) is str: base = self.attention_plus_base if x[0][0] == '+' else self.attention_minus_base weight = pow(base, len(x[0])) - # print("Making attention with children of type", [str(type(x)) for x in x1]) - return Attention(weight=weight, children=x[1]) + if type(x[1]) is list or type(x[1]) is pp.ParseResults: + return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in x[1]]) + elif type(x[1]) is str: + return Attention(weight=weight, children=[Fragment(x[1])]) + elif type(x[1]) is Fragment: + return Attention(weight=weight, children=[x[1]]) + raise PromptParser.ParsingException(f"Don't know how to make attention with children {x[1]}") - attention_with_parens.set_parse_action(make_attention)\ - .set_name("attention_with_parens")\ - .set_debug(False) - - # attention control of the form ++word --word (no parens) - attention_without_parens = ( - (pp.Word('+') | pp.Word('-')) + - pp.CharsNotIn(SPACE_CHARS+'()').set_parse_action(lambda x: [[make_fragment(x)]]) - )\ - .set_name("attention_without_parens")\ - .set_debug(False) + attention_with_parens.set_parse_action(make_attention) attention_without_parens.set_parse_action(make_attention) - attention << (attention_with_parens | attention_without_parens)\ - .set_name("attention")\ - .set_debug(False) - # cross-attention control empty_string = ((lparen + rparen) | pp.Literal('""').suppress() | @@ -345,26 +362,38 @@ class PromptParser(): ).set_parse_action(lambda x: Fragment("")) empty_string.set_name('empty_string') - original_fragment = empty_string | quoted_fragment | parenthesized_fragment | unquoted_fragment + + # cross attention control + debug_cross_attention_control = False + original_fragment = pp.Or([empty_string.set_debug(debug_cross_attention_control), + quoted_fragment.set_debug(debug_cross_attention_control), + parenthesized_fragment.set_debug(debug_cross_attention_control), + unquoted_fragment.set_debug(debug_cross_attention_control)]) edited_fragment = parenthesized_fragment cross_attention_substitute = original_fragment + pp.Literal(".swap").suppress() + edited_fragment - cross_attention_substitute.set_name('cross_attention_substitute').set_debug(True) + original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control) + edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control) + cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control) def make_cross_attention_substitute(x): - print("making cacs for", x) + #print("making cacs for", x) cacs = CrossAttentionControlSubstitute(x[0], x[1]) - print("made", cacs) + #print("made", cacs) return cacs cross_attention_substitute.set_parse_action(make_cross_attention_substitute) + + # simple fragments of text - prompt_part = ( - cross_attention_substitute - | attention - | quoted_fragment - | unquoted_fragment - ) + # use Or to match the longest + prompt_part = pp.Or([ + cross_attention_substitute, + attention, + quoted_fragment, + unquoted_fragment, + lparen + unquoted_fragment + rparen # matches case where user has +(term) and just deletes the + + ]) 
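The weight attached to an attention group follows from make_attention() above: an explicit number is used directly, a run of '+' or '-' signs multiplies out as a power of the plus/minus base (1.1 and 0.9 as exercised by the tests), and nested groups multiply together when the tree is flattened. A small sketch of that arithmetic under those assumptions:

    # Hedged sketch of the sign-run weight arithmetic implied by make_attention();
    # 1.1 and 0.9 are the plus/minus base values the tests rely on.
    def sign_weight(signs: str, plus_base: float = 1.1, minus_base: float = 0.9) -> float:
        base = plus_base if signs[0] == '+' else minus_base
        return pow(base, len(signs))

    assert abs(sign_weight('++') - 1.1 ** 2) < 1e-9    # "++(flames)"   -> weight 1.21
    assert abs(sign_weight('---') - 0.9 ** 3) < 1e-9   # "---(flowers)" -> weight 0.729
    # Nesting multiplies: in 'hairy +(mountain +(man))' the inner group ends up at 1.1 * 1.1.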
prompt_part.set_debug(False) prompt_part.set_name("prompt_part") @@ -373,8 +402,10 @@ class PromptParser(): (quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty') # root prompt definition - prompt = (pp.Group(pp.OneOrMore(prompt_part) | empty) + pp.StringEnd()) \ - .set_parse_action(lambda x: Prompt(x[0])) + prompt = ((pp.OneOrMore(prompt_part) | empty) + pp.StringEnd()) \ + .set_parse_action(lambda x: Prompt(x)) + + # weighted blend of prompts # ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or @@ -418,7 +449,7 @@ class PromptParser(): return Conjunction(parts, weights) conjunction_with_parens_and_quotes.set_parse_action(make_conjunction) - implicit_conjunction = pp.OneOrMore(blend | prompt) + implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction') implicit_conjunction.set_parse_action(lambda x: Conjunction(x)) conjunction = conjunction_with_parens_and_quotes | implicit_conjunction diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 38a24ca529..d053253eb6 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -1,5 +1,7 @@ import unittest +import pyparsing + from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute, \ Fragment @@ -11,39 +13,48 @@ def parse_prompt(prompt_string): #print(f"-> parsed '{prompt_string}' to {parse_result}") return parse_result +def make_basic_conjunction(strings: list[str]): + fragments = [Fragment(x) for x in strings] + return Conjunction([FlattenedPrompt(fragments)]) + +def make_weighted_conjunction(weighted_strings: list[tuple[str,float]]): + fragments = [Fragment(x, w) for x,w in weighted_strings] + return Conjunction([FlattenedPrompt(fragments)]) + + class PromptParserTestCase(unittest.TestCase): def test_empty(self): - self.assertEqual(Conjunction([FlattenedPrompt([('', 1)])]), parse_prompt('')) + self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt('')) def test_basic(self): - self.assertEqual(Conjunction([FlattenedPrompt([('fire (flames)', 1)])]), parse_prompt("fire (flames)")) - self.assertEqual(Conjunction([FlattenedPrompt([("fire flames", 1)])]), parse_prompt("fire flames")) - self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames", 1)])]), parse_prompt("fire, flames")) - self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames , fire", 1)])]), parse_prompt("fire, flames , fire")) + self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)")) + self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames")) + self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames")) + self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire")) def test_attention(self): - self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.5)])]), parse_prompt("0.5(flames)")) - self.assertEqual(Conjunction([FlattenedPrompt([('fire flames', 0.5)])]), parse_prompt("0.5(fire flames)")) - self.assertEqual(Conjunction([FlattenedPrompt([('flames', 1.1)])]), parse_prompt("+(flames)")) - self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.9)])]), parse_prompt("-(flames)")) - self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1), ('flames', 0.5)])]), parse_prompt("fire 0.5(flames)")) - self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(1.1, 
2))])]), parse_prompt("++(flames)")) - self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(0.9, 2))])]), parse_prompt("--(flames)")) - self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))])]), parse_prompt("---(flowers) +++flames")) - self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))])]), parse_prompt("---(flowers) +++flames")) - self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames+', pow(1.1, 3))])]), + self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("0.5(flames)")) + self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("0.5(fire flames)")) + self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("+(flames)")) + self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("-(flames)")) + self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire 0.5(flames)")) + self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("++(flames)")) + self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("--(flames)")) + self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("---(flowers) +++flames")) + self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("---(flowers) +++flames")) + self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames+', pow(1.1, 3))]), parse_prompt("---(flowers) +++flames+")) - self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1)])]), + self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1)]), parse_prompt("+(pretty flowers)")) - self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1), (', the flames are too hot', 1)])]), + self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1), (', the flames are too hot', 1)]), parse_prompt("+(pretty flowers), the flames are too hot")) def test_no_parens_attention_runon(self): - self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("++fire flames")) - self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("--fire flames")) - self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("flowers ++fire flames")) - self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("flowers --fire flames")) + self.assertEqual(make_weighted_conjunction([('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("++fire flames")) + self.assertEqual(make_weighted_conjunction([('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("--fire flames")) + self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("flowers ++fire flames")) + self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("flowers --fire flames")) def test_explicit_conjunction(self): @@ -75,17 +86,27 @@ class PromptParserTestCase(unittest.TestCase): self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt)) assert_if_prompt_string_not_untouched('a test prompt') - assert_if_prompt_string_not_untouched('a badly (formed test prompt') 
assert_if_prompt_string_not_untouched('a badly formed test+ prompt') - assert_if_prompt_string_not_untouched('a badly (formed test+ prompt') - assert_if_prompt_string_not_untouched('a badly (formed test+ )prompt') - assert_if_prompt_string_not_untouched('a badly (formed test+ )prompt') - assert_if_prompt_string_not_untouched('(((a badly (formed test+ )prompt') - assert_if_prompt_string_not_untouched('(a (ba)dly (f)ormed test+ prompt') - self.assertEqual(Conjunction([FlattenedPrompt([('(a (ba)dly (f)ormed test+', 1.0), ('prompt', 1.1)])]), - parse_prompt('(a (ba)dly (f)ormed test+ +prompt')) - self.assertEqual(Conjunction([Blend([FlattenedPrompt([('((a badly (formed test+', 1.0)])], weights=[1.0])]), - parse_prompt('("((a badly (formed test+ ").blend(1.0)')) + with self.assertRaises(pyparsing.ParseException): + parse_prompt('a badly (formed test prompt') + #with self.assertRaises(pyparsing.ParseException): + with self.assertRaises(pyparsing.ParseException): + parse_prompt('a badly (formed test+ prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('a badly (formed test+ )prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('a badly (formed test+ )prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('(((a badly (formed test+ )prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('(a (ba)dly (f)ormed test+ prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('(a (ba)dly (f)ormed test+ +prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('("((a badly (formed test+ ").blend(1.0)') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('mountain (\\"man").swap("monkey")') + def test_blend(self): self.assertEqual(Conjunction( @@ -127,7 +148,7 @@ class PromptParserTestCase(unittest.TestCase): def test_nested(self): - self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)])]), + self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)]), parse_prompt('fire 2.0(flames 1.5(trees))')) self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]), FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])], @@ -202,20 +223,84 @@ class PromptParserTestCase(unittest.TestCase): self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)')) - def make_basic_conjunction(self, strings: list[str]): - fragments = [Fragment(x) for x in strings] - return Conjunction([FlattenedPrompt(fragments)]) - - def make_weighted_conjunction(self, weighted_strings: list[tuple[str,float]]): - fragments = [Fragment(x, w) for x,w in weighted_strings] - return Conjunction([FlattenedPrompt(fragments)]) - def test_escaping(self): - self.assertEqual(self.make_basic_conjunction(['mountain \(man\)']),parse_prompt('mountain \(man\)')) - self.assertEqual(self.make_basic_conjunction(['mountain (\(man)\)']),parse_prompt('mountain (\(man)\)')) - self.assertEqual(self.make_basic_conjunction(['mountain (\(man\))']),parse_prompt('mountain (\(man\))')) - #self.assertEqual(self.make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain +(\(man\))')) + + # make sure ", ( and ) can be escaped + + self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain \(man\)')) + self.assertEqual(make_basic_conjunction(['mountain (man )']),parse_prompt('mountain (\(man)\)')) + 
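        # A hedged summary of the escaping behaviour exercised here: a backslash before
        # " ( or ) makes that character literal, and Fragment() strips the backslash
        # when the text is captured. Expected values as in the surrounding assertions:
        #   parse_prompt('mountain \(man\)')    -> Fragment('mountain (man)', 1)
        #   parse_prompt('mountain +(\(man\))') -> Fragment('mountain', 1), Fragment('(man)', 1.1)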
self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain (\(man\))')) + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('mountain +(\(man\))')) + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('"mountain" +(\(man\))')) + self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('(man)', 1.1)]), parse_prompt('\\"mountain\\" +(\(man\))')) + # same weights for each are combined into one + self.assertEqual(make_weighted_conjunction([('"mountain" (man)', 1.1)]), parse_prompt('+(\\"mountain\\") +(\(man\))')) + self.assertEqual(make_weighted_conjunction([('"mountain"', 1.1), ('(man)', 0.9)]), parse_prompt('+(\\"mountain\\") -(\(man\))')) + + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain 1.1(\(man\))')) + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('"mountain" 1.1(\(man\))')) + self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('\(man\)', 1.1)]),parse_prompt('\\"mountain\\" 1.1(\(man\))')) + # same weights for each are combined into one + self.assertEqual(make_weighted_conjunction([('\\"mountain\\" \(man\)', 1.1)]),parse_prompt('+(\\"mountain\\") 1.1(\(man\))')) + self.assertEqual(make_weighted_conjunction([('\\"mountain\\"', 1.1), ('\(man\)', 0.9)]),parse_prompt('1.1(\\"mountain\\") 0.9(\(man\))')) + + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy +(mountain +(\(man\)))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy +(1.1(\(man\)) "mountain")')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy +("mountain" 1.1(\(man\)) )')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man', 1.1)]),parse_prompt('hairy +("mountain, man")')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy +("mountain, man" with a +beard)')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, man" with a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\"man\\"" with a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, m\\"an\\"" with a 2.0(beard))')) + + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" \(with a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" w\(ith a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" with\( a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" \)with a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man 
w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" w\)ith a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" with\) a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mou\)ntain, \\\"man\" \(wit\(h a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry +("mountain, \\\"man\" w\)ith a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain, \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( +("mountain, \\\"man\" with a 2.0(beard))')) + + self.assertEqual(make_weighted_conjunction([('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" \(with a 2.0(beard)) hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" w\(ith a 2.0(beard))hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" with\( a 2.0(beard)) hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" \)with a 2.0(beard)) hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" w\)ith a 2.0(beard)) hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' +("mountain, \\\"man\" with\) a 2.0(beard)) hairy')) + self.assertEqual(make_weighted_conjunction([('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mou\)ntain, \\\"man\" \(wit\(h a 2.0(beard)) hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('+("mountain, \\\"man\" w\)ith a 2.0(beard)) hai\(ry ')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('+("mountain, \\\"man\" with a 2.0(beard)) hairy\(\( ')) + + def test_cross_attention_escaping(self): + + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]), + parse_prompt('mountain (man).swap(monkey)')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]), + parse_prompt('mountain (man).swap(m\(onkey)')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]), + parse_prompt('mountain (m\(an).swap(m\(onkey)')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]), + parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)')) + + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]), + parse_prompt('mountain 
("man").swap(monkey)')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]), + parse_prompt('mountain ("man").swap("monkey")')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('"man', 1)], [Fragment('monkey', 1)])])]), + parse_prompt('mountain (\\"man).swap("monkey")')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]), + parse_prompt('mountain (man).swap(m\(onkey)')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]), + parse_prompt('mountain (m\(an).swap(m\(onkey)')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]), + parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)')) + + def test_single(self): + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('"man', 1)], [Fragment('monkey', 1)])])]), + parse_prompt('mountain (\\"man).swap("monkey")')) if __name__ == '__main__': From da88097abac211e9769ce51c4dd8f79d6ed64e9f Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Thu, 20 Oct 2022 21:41:32 +0200 Subject: [PATCH 26/54] fix prompt handling in conditioning.py --- ldm/invoke/conditioning.py | 8 ++++---- ldm/invoke/prompt_parser.py | 39 +++++++++++++++++++++---------------- tests/test_prompt_parser.py | 12 ++++++++++++ 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index fb6d8d443e..e3685db615 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -66,10 +66,10 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n edited_prompt = FlattenedPrompt() for fragment in flattened_prompt.children: if type(fragment) is CrossAttentionControlSubstitute: - original_prompt.append(fragment.original_fragment) - edited_prompt.append(fragment.edited_fragment) - elif type(fragment) is CrossAttentionControlAppend: - edited_prompt.append(fragment.fragment) + original_prompt.append(fragment.original) + edited_prompt.append(fragment.edited) + #elif type(fragment) is CrossAttentionControlAppend: + # edited_prompt.append(fragment.fragment) else: # regular fragment original_prompt.append(fragment) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index d576d069aa..68cc102584 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -1,4 +1,5 @@ import string +from typing import Union import pyparsing import pyparsing as pp @@ -17,24 +18,31 @@ class Prompt(): def __eq__(self, other): return type(other) is Prompt and other.children == self.children +class BaseFragment: + pass + class FlattenedPrompt(): - def __init__(self, parts: list): + def __init__(self, parts: list=[]): # verify type correctness - parts_converted = [] + self.children = [] for part in parts: - if issubclass(type(part), BaseFragment): - parts_converted.append(part) - elif type(part) is tuple: - # upgrade tuples to Fragments - if type(part[0]) is not str or (type(part[1]) is not float and type(part[1]) is not int): - raise PromptParser.ParsingException( - f"FlattenedPrompt cannot contain {part}, only Fragments or (str, float) tuples are allowed") - 
parts_converted.append(Fragment(part[0], part[1])) - else: + self.append(part) + + def append(self, fragment: Union[list, BaseFragment, tuple]): + if type(fragment) is list: + for x in fragment: + self.append(x) + elif issubclass(type(fragment), BaseFragment): + self.children.append(fragment) + elif type(fragment) is tuple: + # upgrade tuples to Fragments + if type(fragment[0]) is not str or (type(fragment[1]) is not float and type(fragment[1]) is not int): raise PromptParser.ParsingException( - f"FlattenedPrompt cannot contain {part}, only Fragments or (str, float) tuples are allowed") - # all looks good - self.children = parts_converted + f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed") + self.children.append(Fragment(fragment[0], fragment[1])) + else: + raise PromptParser.ParsingException( + f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed") def __repr__(self): return f"FlattenedPrompt:{self.children}" @@ -42,9 +50,6 @@ class FlattenedPrompt(): return type(other) is FlattenedPrompt and other.children == self.children # abstract base class for Fragments -class BaseFragment: - pass - class Fragment(BaseFragment): def __init__(self, text: str, weight: float=1): assert(type(text) is str) diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index d053253eb6..2bfae0cb48 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -156,6 +156,18 @@ class PromptParserTestCase(unittest.TestCase): parse_prompt('("fire ++(flames)", "mountain 2(man)").blend(1,1)')) def test_cross_attention_control(self): + + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]), + Fragment('eating a hotdog', 1)])]), parse_prompt("a \"cat\".swap(dog) eating a hotdog")) + + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]), + Fragment('eating a hotdog', 1)])]), parse_prompt("a cat.swap(dog) eating a hotdog")) + + fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \ CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])]) self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)')) From da75876639cb94be33c3427b059737879a87636f Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Fri, 21 Oct 2022 00:08:28 +0200 Subject: [PATCH 27/54] better support for word.swap(otherWord) without parantheses or quotes --- ldm/invoke/prompt_parser.py | 21 ++++++++++++--------- tests/test_prompt_parser.py | 2 -- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 68cc102584..2110b8ba34 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -270,7 +270,6 @@ class PromptParser(): else: raise PromptParser.ParsingException("Cannot make fragment from " + str(x)) - unquoted_fragment = pp.Forward() quoted_fragment = pp.Forward() parenthesized_fragment = pp.Forward() @@ -281,7 +280,7 @@ class PromptParser(): fragment_parser = pp.Group(pp.OneOrMore(attention | (greedy_word.set_parse_action(make_fragment)))) fragment_parser.set_name('word_or_attention') result = fragment_parser.parse_string(x[0]) - #result = (pp.OneOrMore(attention | unquoted_fragment) + pp.StringEnd()).parse_string(x[0]) + #result = (pp.OneOrMore(attention | unquoted_word) + 
pp.StringEnd()).parse_string(x[0]) #print("parsed to", result) return result @@ -292,11 +291,15 @@ class PromptParser(): self_unescaping_escaped_lparen = pp.Literal('\\(').set_parse_action(lambda x: '(') self_unescaping_escaped_rparen = pp.Literal('\\)').set_parse_action(lambda x: ')') - unquoted_fragment << pp.Combine(pp.OneOrMore( - self_unescaping_escaped_rparen | self_unescaping_escaped_lparen | self_unescaping_escaped_quote | - pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') - )) - unquoted_fragment.set_parse_action(make_fragment).set_name('unquoted_fragment') + def not_ends_with_swap(x): + #print("trying to match:", x) + return not x[0].endswith('.swap') + + unquoted_fragment = pp.Combine(pp.OneOrMore( + self_unescaping_escaped_rparen | self_unescaping_escaped_lparen | self_unescaping_escaped_quote | + pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()'))) + unquoted_fragment.set_parse_action(make_fragment).set_name('unquoted_fragment').set_debug(True) + #print(unquoted_fragment.parse_string("cat.swap(dog)")) parenthesized_fragment << pp.MatchFirst([ (lparen + quoted_fragment.copy().set_parse_action(parse_fragment_str).set_debug(False) + rparen).set_name('-quoted_paren_internal').set_debug(False), @@ -367,13 +370,13 @@ class PromptParser(): ).set_parse_action(lambda x: Fragment("")) empty_string.set_name('empty_string') - # cross attention control debug_cross_attention_control = False original_fragment = pp.Or([empty_string.set_debug(debug_cross_attention_control), quoted_fragment.set_debug(debug_cross_attention_control), parenthesized_fragment.set_debug(debug_cross_attention_control), - unquoted_fragment.set_debug(debug_cross_attention_control)]) + pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_fragment) + pp.FollowedBy(".swap") + ]) edited_fragment = parenthesized_fragment cross_attention_substitute = original_fragment + pp.Literal(".swap").suppress() + edited_fragment diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 2bfae0cb48..902f4b925c 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -104,8 +104,6 @@ class PromptParserTestCase(unittest.TestCase): parse_prompt('(a (ba)dly (f)ormed test+ +prompt') with self.assertRaises(pyparsing.ParseException): parse_prompt('("((a badly (formed test+ ").blend(1.0)') - with self.assertRaises(pyparsing.ParseException): - parse_prompt('mountain (\\"man").swap("monkey")') def test_blend(self): From 2e0b1c4c8b306f5b76a3049889e1f1c0d76301b0 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Fri, 21 Oct 2022 03:29:50 +0200 Subject: [PATCH 28/54] ok now we're cooking --- ldm/invoke/prompt_parser.py | 463 ++++++++++++++++++++---------------- tests/test_prompt_parser.py | 12 +- 2 files changed, 273 insertions(+), 202 deletions(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 2110b8ba34..ef6a72a49b 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -1,13 +1,17 @@ import string from typing import Union -import pyparsing import pyparsing as pp -from pyparsing import original_text_for - class Prompt(): + """ + Mid-level structure for storing the tree-like result of parsing a prompt. A Prompt may not represent the whole of + the singular user-defined "prompt string" (although it can) - for example, if the user specifies a Blend, the objects + that are to be blended together are stored individuall as Prompt objects. 
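A minimal usage sketch of the parsing entry point (an illustration only, assuming the ldm.invoke.prompt_parser module path used throughout this series; parse() returns an already-flattened Conjunction, as implemented further down in this patch):

from ldm.invoke.prompt_parser import PromptParser

parser = PromptParser()
conjunction = parser.parse('a "cat".swap(dog) eating a hotdog')
for flattened_prompt in conjunction.prompts:
    # each child is a Fragment or a CrossAttentionControlSubstitute, ready for the tokenizer
    print(flattened_prompt.children)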
+ Nesting makes this object not suitable for directly tokenizing; instead call flatten() on the containing Conjunction + to produce a FlattenedPrompt. + """ def __init__(self, parts: list): for c in parts: if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults: @@ -22,13 +26,16 @@ class BaseFragment: pass class FlattenedPrompt(): + """ + A Prompt that has been passed through flatten(). Its children can be readily tokenized. + """ def __init__(self, parts: list=[]): - # verify type correctness self.children = [] for part in parts: self.append(part) def append(self, fragment: Union[list, BaseFragment, tuple]): + # verify type correctness if type(fragment) is list: for x in fragment: self.append(x) @@ -49,8 +56,11 @@ class FlattenedPrompt(): def __eq__(self, other): return type(other) is FlattenedPrompt and other.children == self.children -# abstract base class for Fragments + class Fragment(BaseFragment): + """ + A Fragment is a chunk of plain text and an optional weight. The text should be passed as-is to the CLIP tokenizer. + """ def __init__(self, text: str, weight: float=1): assert(type(text) is str) if '\\"' in text or '\\(' in text or '\\)' in text: @@ -67,6 +77,12 @@ class Fragment(BaseFragment): and other.weight == self.weight class Attention(): + """ + Nestable weight control for fragments. Each object in the children array may in turn be an Attention object; + weights should be considered to accumulate as the tree is traversed to deeper levels of nesting. + + Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object. + """ def __init__(self, weight: float, children: list): self.weight = weight self.children = children @@ -81,7 +97,28 @@ class CrossAttentionControlledFragment(BaseFragment): pass class CrossAttentionControlSubstitute(CrossAttentionControlledFragment): - def __init__(self, original: Fragment, edited: Fragment): + """ + A Cross-Attention Controlled ('prompt2prompt') fragment, for use inside a Prompt, Attention, or FlattenedPrompt. + Representing an "original" word sequence that supplies feature vectors for an initial diffusion operation, and an + "edited" word sequence, to which the attention maps produced by the "original" word sequence are applied. Intuitively, + the result should be an "edited" image that looks like the "original" image with concepts swapped. + + eg "a cat sitting on a car" (original) -> "a smiling dog sitting on a car" (edited): the edited image should look + almost exactly the same as the original, but with a smiling dog rendered in place of the cat. 
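Roughly how the two sides are consumed downstream, sketched from the conditioning.py hunk earlier in this series (the mirroring of regular fragments into the edited prompt is an assumption, not shown in the visible hunk):

original_prompt = FlattenedPrompt()
edited_prompt = FlattenedPrompt()
for fragment in flattened_prompt.children:
    if type(fragment) is CrossAttentionControlSubstitute:
        original_prompt.append(fragment.original)  # supplies the initial feature vectors
        edited_prompt.append(fragment.edited)      # receives the saved attention maps
    else:
        # regular fragment: kept in the original prompt (and presumably in the edited one too)
        original_prompt.append(fragment)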
The + CrossAttentionControlSubstitute object representing this swap may be confined to the tokens being swapped: + CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]) + or it may represent a larger portion of the token sequence: + CrossAttentionControlSubstitute(original=[Fragment('a cat sitting on a car')], + edited=[Fragment('a smiling dog sitting on a car')]) + + In either case expect it to be embedded in a Prompt or FlattenedPrompt: + FlattenedPrompt([ + Fragment('a'), + CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]), + Fragment('sitting on a car') + ]) + """ + def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list]): self.original = original self.edited = edited @@ -92,6 +129,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment): and other.original == self.original \ and other.edited == self.edited + class CrossAttentionControlAppend(CrossAttentionControlledFragment): def __init__(self, fragment: Fragment): self.fragment = fragment @@ -104,6 +142,10 @@ class CrossAttentionControlAppend(CrossAttentionControlledFragment): class Conjunction(): + """ + Storage for one or more Prompts or Blends, each of which is to be separately diffused and then the results merged + by weighted sum in latent space. + """ def __init__(self, prompts: list, weights: list = None): # force everything to be a Prompt #print("making conjunction with", parts) @@ -125,6 +167,11 @@ class Conjunction(): class Blend(): + """ + Stores a Blend of multiple Prompts. To apply, build feature vectors for each of the child Prompts and then perform a + weighted blend of the feature vectors to produce a single feature vector that is effectively a lerp between the + Prompts. + """ def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True): #print("making Blend with prompts", prompts, "and weights", weights) if len(prompts) != len(weights): @@ -152,16 +199,11 @@ class PromptParser(): def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9): - self.attention_plus_base = attention_plus_base - self.attention_minus_base = attention_minus_base - - self.root = self.build_parser_logic() + self.root = build_parser_syntax(attention_plus_base, attention_minus_base) def parse(self, prompt: str) -> Conjunction: ''' - This parser is *very* forgiving. If it cannot parse syntax, it will return strings as-is to be passed on to the - diffusion. :param prompt: The prompt string to parse :return: a Conjunction representing the parsed results. ''' @@ -177,7 +219,16 @@ class PromptParser(): return self.flatten(root[0]) - def flatten(self, root: Conjunction): + + def flatten(self, root: Conjunction) -> Conjunction: + """ + Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends, + producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects + that can be readily tokenized without the need to walk a complex tree structure. + + :param root: The Conjunction to flatten. + :return: A Conjunction containing the result of flattening each of the prompts in the passed-in root. 
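A concrete instance of what flattening yields for nested attention, essentially as exercised by the tests later in this patch (weights accumulate multiplicatively, with attention_plus_base defaulting to 1.1):

parse_prompt('a person with a hat ++(riding a bicycle)')
# -> Conjunction([FlattenedPrompt([('a person with a hat', 1.0),
#                                  ('riding a bicycle', 1.1 ** 2)])])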
+ """ #print("flattening", root) @@ -242,226 +293,238 @@ class PromptParser(): flattened_parts = [] for part in root.prompts: flattened_parts += flatten_internal(part, 1.0, [], ' C| ') + + #print("flattened to", flattened_parts) + weights = root.weights return Conjunction(flattened_parts, weights) - def build_parser_logic(self): +def build_parser_syntax(attention_plus_base: float, attention_minus_base: float): - lparen = pp.Literal("(").suppress() - rparen = pp.Literal(")").suppress() - quotes = pp.Literal('"').suppress() + lparen = pp.Literal("(").suppress() + rparen = pp.Literal(")").suppress() + quotes = pp.Literal('"').suppress() - # accepts int or float notation, always maps to float - number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float)) - SPACE_CHARS = string.whitespace - greedy_word = pp.Word(pp.printables, exclude_chars=string.whitespace).set_name('greedy_word') + # accepts int or float notation, always maps to float + number = pp.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float)) + greedy_word = pp.Word(pp.printables, exclude_chars=string.whitespace).set_name('greedy_word') - attention = pp.Forward() + attention = pp.Forward() + quoted_fragment = pp.Forward() + parenthesized_fragment = pp.Forward() + cross_attention_substitute = pp.Forward() + prompt_part = pp.Forward() - def make_fragment(x): - #print("### making fragment for", x) - if type(x) is str: - return Fragment(x) - elif type(x) is pp.ParseResults or type(x) is list: - #print(f'converting {type(x).__name__} to Fragment') - return Fragment(' '.join([s for s in x])) - else: - raise PromptParser.ParsingException("Cannot make fragment from " + str(x)) + def make_text_fragment(x): + #print("### making fragment for", x) + if type(x) is str: + return Fragment(x) + elif type(x) is pp.ParseResults or type(x) is list: + #print(f'converting {type(x).__name__} to Fragment') + return Fragment(' '.join([s for s in x])) + else: + raise PromptParser.ParsingException("Cannot make fragment from " + str(x)) - quoted_fragment = pp.Forward() - parenthesized_fragment = pp.Forward() + def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False): + fragment_string = x[0] + print(f"parsing fragment string \"{fragment_string}\"") + if len(fragment_string.strip()) == 0: + return Fragment('') - def parse_fragment_str(x): - #print("parsing fragment string", x) - if len(x[0].strip()) == 0: - return Fragment('') - fragment_parser = pp.Group(pp.OneOrMore(attention | (greedy_word.set_parse_action(make_fragment)))) - fragment_parser.set_name('word_or_attention') - result = fragment_parser.parse_string(x[0]) - #result = (pp.OneOrMore(attention | unquoted_word) + pp.StringEnd()).parse_string(x[0]) - #print("parsed to", result) - return result + if in_quotes: + # escape unescaped quotes + fragment_string = fragment_string.replace('"', '\\"') - quoted_fragment << pp.QuotedString(quote_char='"', esc_char='\\') - quoted_fragment.set_parse_action(parse_fragment_str).set_name('quoted_fragment') + #fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment)))) + result = pp.Group(pp.MatchFirst([ + pp.OneOrMore(prompt_part | quoted_fragment), + pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd() + ])).set_name('rr').set_debug(False).parse_string(fragment_string) + #result = (pp.OneOrMore(attention | unquoted_word) + pp.StringEnd()).parse_string(x[0]) + #print("parsed to", result) + return result - 
self_unescaping_escaped_quote = pp.Literal('\\"').set_parse_action(lambda x: '"') - self_unescaping_escaped_lparen = pp.Literal('\\(').set_parse_action(lambda x: '(') - self_unescaping_escaped_rparen = pp.Literal('\\)').set_parse_action(lambda x: ')') + quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"') + quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment') - def not_ends_with_swap(x): - #print("trying to match:", x) - return not x[0].endswith('.swap') + escaped_quote = pp.Literal('\\"')#.set_parse_action(lambda x: '"') + escaped_lparen = pp.Literal('\\(')#.set_parse_action(lambda x: '(') + escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')') + escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"') - unquoted_fragment = pp.Combine(pp.OneOrMore( - self_unescaping_escaped_rparen | self_unescaping_escaped_lparen | self_unescaping_escaped_quote | - pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()'))) - unquoted_fragment.set_parse_action(make_fragment).set_name('unquoted_fragment').set_debug(True) - #print(unquoted_fragment.parse_string("cat.swap(dog)")) + def not_ends_with_swap(x): + #print("trying to match:", x) + return not x[0].endswith('.swap') - parenthesized_fragment << pp.MatchFirst([ - (lparen + quoted_fragment.copy().set_parse_action(parse_fragment_str).set_debug(False) + rparen).set_name('-quoted_paren_internal').set_debug(False), - (lparen + rparen).set_parse_action(lambda x: make_fragment('')).set_name('-()').set_debug(False), - (lparen + pp.Combine(pp.OneOrMore( - pp.Literal('\\"').set_debug(False).set_parse_action(lambda x: '"') | - pp.Literal('\\(').set_debug(False).set_parse_action(lambda x: '(') | - pp.Literal('\\)').set_debug(False).set_parse_action(lambda x: ')') | - pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') | - pp.Word(string.whitespace) - )).set_name('--combined').set_parse_action(parse_fragment_str).set_debug(False) + rparen)]).set_name('-unquoted_paren_internal').set_debug(False) - parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False) + unquoted_fragment = pp.Combine(pp.OneOrMore( + escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash | + pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()'))) + unquoted_fragment.set_parse_action(make_text_fragment).set_name('unquoted_fragment').set_debug(False) + #print(unquoted_fragment.parse_string("cat.swap(dog)")) - debug_attention = False - # attention control of the form +(phrase) / -(phrase) / (phrase) - # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight - attention_head = (number | pp.Word('+') | pp.Word('-'))\ - .set_name("attention_head")\ - .set_debug(False) - word_inside_attention = pp.Combine(pp.OneOrMore( - pp.Literal('\\)') | pp.Literal('\\(') | pp.Literal('\\"') | - pp.Word(pp.printables, exclude_chars=string.whitespace + '\\()"') - )).set_name('word_inside_attention') - attention_with_parens = pp.Forward() - attention_with_parens_delimited_list = pp.delimited_list(pp.Or([ - quoted_fragment.copy().set_debug(debug_attention), - attention.copy().set_debug(debug_attention), - word_inside_attention.set_debug(debug_attention)]).set_name('delim_inner').set_debug(debug_attention), - delim=string.whitespace) - # have to disable ignore_expr here to prevent pyparsing from stripping off quote marks - attention_with_parens_body = 
pp.nested_expr(content=attention_with_parens_delimited_list, - ignore_expr=None#((pp.Literal("\\(") | pp.Literal('\\)'))) - ) - attention_with_parens_body.set_debug(debug_attention) - attention_with_parens << (attention_head + attention_with_parens_body) - attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention) + parenthesized_fragment << pp.Or([ + (lparen + quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False) + rparen).set_name('-quoted_paren_internal').set_debug(False), + (lparen + rparen).set_parse_action(lambda x: make_text_fragment('')).set_name('-()').set_debug(False), + (lparen + pp.Combine(pp.OneOrMore( + escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash | + pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') | + pp.Word(string.whitespace) + )).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False) + rparen)]).set_name('-unquoted_paren_internal').set_debug(False) + parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False) - attention_without_parens = (pp.Word('+') | pp.Word('-')) + (quoted_fragment | word_inside_attention) - attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention) + debug_attention = False + # attention control of the form +(phrase) / -(phrase) / (phrase) + # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight + attention_head = (number | pp.Word('+') | pp.Word('-'))\ + .set_name("attention_head")\ + .set_debug(False) + word_inside_attention = pp.Combine(pp.OneOrMore( + pp.Literal('\\)') | pp.Literal('\\(') | pp.Literal('\\"') | + pp.Word(pp.printables, exclude_chars=string.whitespace + '\\()"') + )).set_name('word_inside_attention') + attention_with_parens = pp.Forward() - attention << (attention_with_parens | attention_without_parens) + attention_with_parens_delimited_list = pp.OneOrMore(pp.Or([ + quoted_fragment.copy().set_debug(debug_attention), + attention.copy().set_debug(debug_attention), + cross_attention_substitute, + word_inside_attention.set_debug(debug_attention) + #pp.White() + ]).set_name('delim_inner').set_debug(debug_attention)) + # have to disable ignore_expr here to prevent pyparsing from stripping off quote marks + attention_with_parens_body = pp.nested_expr(content=attention_with_parens_delimited_list, + ignore_expr=None#((pp.Literal("\\(") | pp.Literal('\\)'))) + ) + attention_with_parens_body.set_debug(debug_attention) + attention_with_parens << (attention_head + attention_with_parens_body) + attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention) - def make_attention(x): - #print("making Attention from", x) - weight = 1 - # number(str) - if type(x[0]) is float or type(x[0]) is int: - weight = float(x[0]) - # +(str) or -(str) or +str or -str - elif type(x[0]) is str: - base = self.attention_plus_base if x[0][0] == '+' else self.attention_minus_base - weight = pow(base, len(x[0])) - if type(x[1]) is list or type(x[1]) is pp.ParseResults: - return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in x[1]]) - elif type(x[1]) is str: - return Attention(weight=weight, children=[Fragment(x[1])]) - elif type(x[1]) is Fragment: - return Attention(weight=weight, children=[x[1]]) - raise PromptParser.ParsingException(f"Don't know how to make attention with children {x[1]}") + attention_without_parens = (pp.Word('+') | 
pp.Word('-')) + (quoted_fragment | word_inside_attention) + attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention) - attention_with_parens.set_parse_action(make_attention) - attention_without_parens.set_parse_action(make_attention) + attention << (attention_with_parens | attention_without_parens) + attention.set_name('attention') - # cross-attention control - empty_string = ((lparen + rparen) | - pp.Literal('""').suppress() | - (lparen + pp.Literal('""').suppress() + rparen) - ).set_parse_action(lambda x: Fragment("")) - empty_string.set_name('empty_string') + def make_attention(x): + #print("making Attention from", x) + weight = 1 + # number(str) + if type(x[0]) is float or type(x[0]) is int: + weight = float(x[0]) + # +(str) or -(str) or +str or -str + elif type(x[0]) is str: + base = attention_plus_base if x[0][0] == '+' else attention_minus_base + weight = pow(base, len(x[0])) + if type(x[1]) is list or type(x[1]) is pp.ParseResults: + return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in x[1]]) + elif type(x[1]) is str: + return Attention(weight=weight, children=[Fragment(x[1])]) + elif type(x[1]) is Fragment: + return Attention(weight=weight, children=[x[1]]) + raise PromptParser.ParsingException(f"Don't know how to make attention with children {x[1]}") - # cross attention control - debug_cross_attention_control = False - original_fragment = pp.Or([empty_string.set_debug(debug_cross_attention_control), - quoted_fragment.set_debug(debug_cross_attention_control), - parenthesized_fragment.set_debug(debug_cross_attention_control), - pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_fragment) + pp.FollowedBy(".swap") - ]) - edited_fragment = parenthesized_fragment - cross_attention_substitute = original_fragment + pp.Literal(".swap").suppress() + edited_fragment + attention_with_parens.set_parse_action(make_attention) + attention_without_parens.set_parse_action(make_attention) - original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control) - edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control) - cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control) + # cross-attention control + empty_string = ((lparen + rparen) | + pp.Literal('""').suppress() | + (lparen + pp.Literal('""').suppress() + rparen) + ).set_parse_action(lambda x: Fragment("")) + empty_string.set_name('empty_string') - def make_cross_attention_substitute(x): - #print("making cacs for", x) - cacs = CrossAttentionControlSubstitute(x[0], x[1]) - #print("made", cacs) - return cacs - cross_attention_substitute.set_parse_action(make_cross_attention_substitute) + # cross attention control + debug_cross_attention_control = False + original_fragment = pp.Or([empty_string.set_debug(debug_cross_attention_control), + quoted_fragment.set_debug(debug_cross_attention_control), + parenthesized_fragment.set_debug(debug_cross_attention_control), + pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_text_fragment) + pp.FollowedBy(".swap") + ]) + edited_fragment = parenthesized_fragment + cross_attention_substitute << original_fragment + pp.Literal(".swap").suppress() + edited_fragment + + original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control) + edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control) + 
cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control) + + def make_cross_attention_substitute(x): + #print("making cacs for", x) + cacs = CrossAttentionControlSubstitute(x[0], x[1]) + #print("made", cacs) + return cacs + cross_attention_substitute.set_parse_action(make_cross_attention_substitute) + + + # simple fragments of text + # use Or to match the longest + prompt_part << pp.MatchFirst([ + cross_attention_substitute, + attention, + unquoted_fragment, + lparen + unquoted_fragment + rparen # matches case where user has +(term) and just deletes the + + ]) + prompt_part.set_debug(False) + prompt_part.set_name("prompt_part") + + empty = ( + (lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) | + (quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty') + + # root prompt definition + prompt = ((pp.OneOrMore(prompt_part | quoted_fragment) | empty) + pp.StringEnd()) \ + .set_parse_action(lambda x: Prompt(x)) - # simple fragments of text - # use Or to match the longest - prompt_part = pp.Or([ - cross_attention_substitute, - attention, - quoted_fragment, - unquoted_fragment, - lparen + unquoted_fragment + rparen # matches case where user has +(term) and just deletes the + - ]) - prompt_part.set_debug(False) - prompt_part.set_name("prompt_part") + # weighted blend of prompts + # ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or + # int weights. + # can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c) - empty = ( - (lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) | - (quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty') + def make_prompt_from_quoted_string(x): + #print(' got quoted prompt', x) - # root prompt definition - prompt = ((pp.OneOrMore(prompt_part) | empty) + pp.StringEnd()) \ - .set_parse_action(lambda x: Prompt(x)) + x_unquoted = x[0][1:-1] + if len(x_unquoted.strip()) == 0: + # print(' b : just an empty string') + return Prompt([Fragment('')]) + # print(' b parsing ', c_unquoted) + x_parsed = prompt.parse_string(x_unquoted) + #print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed) + return x_parsed[0] + + quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string) + quoted_prompt.set_name('quoted_prompt') + + blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms') + blend_weights = pp.delimited_list(number).set_name('blend_weights') + blend = pp.Group(lparen + pp.Group(blend_terms) + rparen + + pp.Literal(".blend").suppress() + + lparen + pp.Group(blend_weights) + rparen).set_name('blend') + blend.set_debug(False) + blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1])) - # weighted blend of prompts - # ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or - # int weights. 
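A worked instance of this .blend() syntax, matching a test added later in this series (negative weights are allowed, and '--' attenuates by attention_minus_base squared, i.e. 0.9 ** 2):

parse_prompt('("mountain, man, hairy", "face, teeth, --eyes").blend(1, -1)')
# -> Conjunction([Blend([FlattenedPrompt([('mountain, man, hairy', 1.0)]),
#                        FlattenedPrompt([('face, teeth,', 1.0), ('eyes', 0.9 ** 2)])],
#                       weights=[1.0, -1.0])])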
- # can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c) + conjunction_terms = blend_terms.copy().set_name('conjunction_terms') + conjunction_weights = blend_weights.copy().set_name('conjunction_weights') + conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen + + pp.Literal(".and").suppress() + + lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction') + def make_conjunction(x): + parts_raw = x[0][0] + weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw) + parts = [part for part in parts_raw] + return Conjunction(parts, weights) + conjunction_with_parens_and_quotes.set_parse_action(make_conjunction) - def make_prompt_from_quoted_string(x): - #print(' got quoted prompt', x) + implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction') + implicit_conjunction.set_parse_action(lambda x: Conjunction(x)) - x_unquoted = x[0][1:-1] - if len(x_unquoted.strip()) == 0: - # print(' b : just an empty string') - return Prompt([Fragment('')]) - # print(' b parsing ', c_unquoted) - x_parsed = prompt.parse_string(x_unquoted) - #print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed) - return x_parsed[0] + conjunction = conjunction_with_parens_and_quotes | implicit_conjunction + conjunction.set_debug(False) - quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string) - quoted_prompt.set_name('quoted_prompt') - - blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms') - blend_weights = pp.delimited_list(number).set_name('blend_weights') - blend = pp.Group(lparen + pp.Group(blend_terms) + rparen - + pp.Literal(".blend").suppress() - + lparen + pp.Group(blend_weights) + rparen).set_name('blend') - blend.set_debug(False) - - - blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1])) - - conjunction_terms = blend_terms.copy().set_name('conjunction_terms') - conjunction_weights = blend_weights.copy().set_name('conjunction_weights') - conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen - + pp.Literal(".and").suppress() - + lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction') - def make_conjunction(x): - parts_raw = x[0][0] - weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw) - parts = [part for part in parts_raw] - return Conjunction(parts, weights) - conjunction_with_parens_and_quotes.set_parse_action(make_conjunction) - - implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction') - implicit_conjunction.set_parse_action(lambda x: Conjunction(x)) - - conjunction = conjunction_with_parens_and_quotes | implicit_conjunction - conjunction.set_debug(False) - - # top-level is a conjunction of one or more blends or prompts - return conjunction + # top-level is a conjunction of one or more blends or prompts + return conjunction diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 902f4b925c..9f12283333 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -77,6 +77,15 @@ class PromptParserTestCase(unittest.TestCase): def test_complex_conjunction(self): self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]), parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle)\").and(0.5, 0.5)")) + self.assertEqual(Conjunction([FlattenedPrompt([("mountain 
man", 1.0)]), + FlattenedPrompt([("a person with a hat", 1.0), + ("riding a", 1.1*1.1), + CrossAttentionControlSubstitute( + [Fragment("bicycle", pow(1.1,2))], + [Fragment("skateboard", pow(1.1,2))]) + ]) + ], weights=[0.5, 0.5]), + parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle.swap(skateboard))\").and(0.5, 0.5)")) def test_badly_formed(self): def make_untouched_prompt(prompt): @@ -309,8 +318,7 @@ class PromptParserTestCase(unittest.TestCase): parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)')) def test_single(self): - self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('"man', 1)], [Fragment('monkey', 1)])])]), - parse_prompt('mountain (\\"man).swap("monkey")')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mou\)ntain, \\\"man\" \(wit\(h a 2.0(beard))')) if __name__ == '__main__': From 404d59b1b8704ad53dd7bf70b9ef00d1a0528f6f Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Fri, 21 Oct 2022 04:15:10 +0200 Subject: [PATCH 29/54] fix blend --- ldm/invoke/conditioning.py | 7 ++++--- ldm/invoke/prompt_parser.py | 10 ++++++---- tests/test_prompt_parser.py | 6 ++++++ 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index 189af10c24..dd41869311 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -48,7 +48,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n edited_conditioning = None edit_opcodes = None - if parsed_prompt is Blend: + if type(parsed_prompt) is Blend: blend: Blend = parsed_prompt embeddings_to_blend = None for flattened_prompt in blend.prompts: @@ -60,7 +60,8 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n normalize=blend.normalize_weights) else: flattened_prompt: FlattenedPrompt = parsed_prompt - wants_cross_attention_control = any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children]) + wants_cross_attention_control = type(flattened_prompt) is not Blend \ + and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children]) if wants_cross_attention_control: original_prompt = FlattenedPrompt() edited_prompt = FlattenedPrompt() @@ -95,7 +96,7 @@ def build_token_edit_opcodes(original_tokens, edited_tokens): def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt): if type(flattened_prompt) is not FlattenedPrompt: - raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead" + raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead") fragments = [x.text for x in flattened_prompt.children] weights = [x.weight for x in flattened_prompt.children] embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights]) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 1d5fb3c04a..9a7206bf42 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -308,7 +308,8 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) quotes = pp.Literal('"').suppress() # accepts int or float notation, always maps to float - number = pp.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float)) + number = pp.pyparsing_common.real | \ + 
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float)) greedy_word = pp.Word(pp.printables, exclude_chars=string.whitespace).set_name('greedy_word') attention = pp.Forward() @@ -498,12 +499,13 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string) quoted_prompt.set_name('quoted_prompt') - blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms') - blend_weights = pp.delimited_list(number).set_name('blend_weights') + debug_blend=True + blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend) + blend_weights = pp.delimited_list(number).set_name('blend_weights').set_debug(debug_blend) blend = pp.Group(lparen + pp.Group(blend_terms) + rparen + pp.Literal(".blend").suppress() + lparen + pp.Group(blend_weights) + rparen).set_name('blend') - blend.set_debug(False) + blend.set_debug(debug_blend) blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1])) diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 9f12283333..84971fcc52 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -153,6 +153,12 @@ class PromptParserTestCase(unittest.TestCase): parse_prompt("(\"fire\", \" , \").blend(0.7, 1)") ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('mountain, man, hairy', 1)]), + FlattenedPrompt([('face, teeth,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0])]), + parse_prompt('("mountain, man, hairy", "face, teeth, --eyes").blend(1,-1)') + ) + def test_nested(self): self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)]), From d9655401039630fa7ebee83c0a5b2efb95065a0b Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Fri, 21 Oct 2022 04:23:19 +0200 Subject: [PATCH 30/54] more blend fixes --- ldm/invoke/conditioning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index dd41869311..6a44986f8d 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -52,10 +52,10 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n blend: Blend = parsed_prompt embeddings_to_blend = None for flattened_prompt in blend.prompts: - this_embedding = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt) + this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt) embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat( (embeddings_to_blend, this_embedding)) - conditioning, _ = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0), + conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0), blend.weights, normalize=blend.normalize_weights) else: From b385fdd7dec0f9207e1d7a2e2dcd94cc32bcd67b Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Fri, 21 Oct 2022 04:34:53 +0200 Subject: [PATCH 31/54] non-normalized blend --- ldm/invoke/prompt_parser.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 9a7206bf42..39138e5364 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -187,7 +187,7 @@ class Blend(): self.normalize_weights = normalize_weights def __repr__(self): - return f"Blend:{self.prompts} | weights {self.weights}" + return 
f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}" def __eq__(self, other): return other.__repr__() == self.__repr__() @@ -276,7 +276,7 @@ class PromptParser(): for prompt in node.prompts: # prompt is a list flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ') - results += [Blend(prompts=flattened_subprompts, weights=node.weights)] + results += [Blend(prompts=flattened_subprompts, weights=node.weights, normalize_weights=node.normalize_weights)] elif type(node) is Prompt: #print(prefix + "about to flatten Prompt with children", node.children) flattened_prompt = [] @@ -501,14 +501,22 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) debug_blend=True blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend) - blend_weights = pp.delimited_list(number).set_name('blend_weights').set_debug(debug_blend) + blend_weights = (pp.delimited_list(number) + pp.Optional(pp.Char(",").suppress() + "no_normalize")).set_name('blend_weights').set_debug(debug_blend) blend = pp.Group(lparen + pp.Group(blend_terms) + rparen + pp.Literal(".blend").suppress() + lparen + pp.Group(blend_weights) + rparen).set_name('blend') blend.set_debug(debug_blend) + def make_blend(x): + prompts = x[0][0] + weights = x[0][1] + normalize = True + if weights[-1] == 'no_normalize': + normalize = False + weights = weights[:-1] + return Blend(prompts=prompts, weights=weights, normalize_weights=normalize) - blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1])) + blend.set_parse_action(make_blend) conjunction_terms = blend_terms.copy().set_name('conjunction_terms') conjunction_weights = blend_weights.copy().set_name('conjunction_weights') From dc2f30a34ed45df886bb73e29b24a819d34b6de2 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Fri, 21 Oct 2022 11:59:42 +0200 Subject: [PATCH 32/54] put back txt2mask import --- ldm/generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ldm/generate.py b/ldm/generate.py index 39bcc28162..45ed2e73d1 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -35,7 +35,7 @@ from ldm.invoke.devices import choose_torch_device, choose_precision from ldm.invoke.conditioning import get_uc_and_c_and_ec from ldm.invoke.model_cache import ModelCache from ldm.invoke.seamless import configure_model_padding -#from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale +from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale def fix_func(orig): if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): From 2bf9f1f0d80308a927ae717fc72f20733bc0c7a0 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Fri, 21 Oct 2022 12:18:40 +0200 Subject: [PATCH 33/54] rename StrcuturedConditioning to ExtraConditioningInfo --- ldm/invoke/generator/img2img.py | 2 +- ldm/invoke/generator/txt2img.py | 2 +- ldm/invoke/generator/txt2img2img.py | 2 +- .../diffusion/shared_invokeai_diffusion.py | 33 +++++++++---------- 4 files changed, 18 insertions(+), 21 deletions(-) diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index 6fa0d0c6dd..cfe3ff99bc 100644 --- a/ldm/invoke/generator/img2img.py +++ b/ldm/invoke/generator/img2img.py @@ -34,7 +34,7 @@ class Img2Img(Generator): t_enc = int(strength * steps) uc, c, ec, edit_opcodes = conditioning - extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes) + extra_conditioning_info 
= InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes) def make_image(x_T): # encode (scaled latent) diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index 657cccc592..7e739860c3 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -22,7 +22,7 @@ class Txt2Img(Generator): """ self.perlin = perlin uc, c, ec, edit_opcodes = conditioning - extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes) + extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes) @torch.no_grad() def make_image(x_T): diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index 64d0468418..2d67a44346 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -24,7 +24,7 @@ class Txt2Img2Img(Generator): kwargs are 'width' and 'height' """ uc, c, ec, edit_opcodes = conditioning - extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes) + extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes) @torch.no_grad() def make_image(x_T): diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 0a613091d5..290925fc8c 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -6,23 +6,6 @@ import torch class InvokeAIDiffuserComponent: - - class StructuredConditioning: - def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None): - """ - :param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768) - :param edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how to map original conditioning tokens to edited conditioning tokens - """ - # TODO migrate conditioning and unconditioning here, too - #self.conditioning = conditioning - #self.unconditioning = unconditioning - self.edited_conditioning = edited_conditioning - self.edit_opcodes = edit_opcodes - - @property - def wants_cross_attention_control(self): - return self.edited_conditioning is not None - ''' The aim of this component is to provide a single place for code that can be applied identically to all InvokeAI diffusion procedures. 
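For context, a condensed sketch of how the renamed class is constructed at the generator call sites changed above, and the property it exposes:

uc, c, ec, edit_opcodes = conditioning
extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(
    edited_conditioning=ec,     # 1 x 77 x 768 edited conditioning, or None
    edit_opcodes=edit_opcodes)  # SequenceMatcher opcodes mapping original tokens to edited tokens
if extra_conditioning_info.wants_cross_attention_control:
    # the sampler hands this info to the diffuser component via setup_cross_attention_control()
    pass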
@@ -31,6 +14,20 @@ class InvokeAIDiffuserComponent: * Cross Attention Control ("prompt2prompt") ''' + + class ExtraConditioningInfo: + def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None): + """ + :param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768) + :param edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how to map original conditioning tokens to edited conditioning tokens + """ + self.edited_conditioning = edited_conditioning + self.edit_opcodes = edit_opcodes + + @property + def wants_cross_attention_control(self): + return self.edited_conditioning is not None + def __init__(self, model, model_forward_callback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]): """ :param model: the unet model to pass through to cross attention control @@ -40,7 +37,7 @@ class InvokeAIDiffuserComponent: self.model_forward_callback = model_forward_callback - def setup_cross_attention_control(self, conditioning: StructuredConditioning): + def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo): self.conditioning = conditioning CrossAttentionControl.setup_cross_attention_control(self.model, conditioning.edited_conditioning, conditioning.edit_opcodes) From e574a1574f999ceb97b98f945346a6d63260035a Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Fri, 21 Oct 2022 12:42:07 +0200 Subject: [PATCH 34/54] txt2mask.py now tracking development again --- ldm/invoke/txt2mask.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/ldm/invoke/txt2mask.py b/ldm/invoke/txt2mask.py index bc8251abde..01d93546e3 100644 --- a/ldm/invoke/txt2mask.py +++ b/ldm/invoke/txt2mask.py @@ -29,9 +29,9 @@ work fine. import torch import numpy as np -from clipseg_models.clipseg import CLIPDensePredT +from models.clipseg import CLIPDensePredT from einops import rearrange, repeat -from PIL import Image, ImageOps +from PIL import Image from torchvision import transforms CLIP_VERSION = 'ViT-B/16' @@ -50,14 +50,9 @@ class SegmentedGrayscale(object): discrete_heatmap = self.heatmap.lt(threshold).int() return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L')) - def to_transparent(self,invert:bool=False)->Image: + def to_transparent(self)->Image: transparent_image = self.image.copy() - gs = self.to_grayscale() - # The following line looks like a bug, but isn't. - # For img2img, we want the selected regions to be transparent, - # but to_grayscale() returns the opposite. - gs = ImageOps.invert(gs) if not invert else gs - transparent_image.putalpha(gs) + transparent_image.putalpha(self.to_grayscale()) return transparent_image # unscales and uncrops the 352x352 heatmap so that it matches the image again @@ -84,7 +79,7 @@ class Txt2Mask(object): self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False) @torch.no_grad() - def segment(self, image, prompt:str) -> SegmentedGrayscale: + def segment(self, image:Image, prompt:str) -> SegmentedGrayscale: ''' Given a prompt string such as "a bagel", tries to identify the object in the provided image and returns a SegmentedGrayscale object in which the brighter @@ -99,10 +94,6 @@ class Txt2Mask(object): transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64... 
]) - if type(image) is str: - image = Image.open(image).convert('RGB') - - image = ImageOps.exif_transpose(image) img = self._scale_and_crop(image) img = transform(img).unsqueeze(0) From 64051d081c320e964571fa7dd0ac63557d2cd5db Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Fri, 21 Oct 2022 15:07:11 +0200 Subject: [PATCH 35/54] cleanup --- backend/server.py | 2 +- cross_attention_loop.py | 186 -------------------------- ldm/invoke/generator/img2img.py | 3 +- ldm/invoke/txt2mask.py | 19 ++- ldm/models/diffusion/ddim.py | 3 +- ldm/modules/diffusionmodules/model.py | 10 +- 6 files changed, 22 insertions(+), 201 deletions(-) delete mode 100644 cross_attention_loop.py diff --git a/backend/server.py b/backend/server.py index f14c141e12..7b8a8a5a69 100644 --- a/backend/server.py +++ b/backend/server.py @@ -527,7 +527,7 @@ def parameters_to_generated_image_metadata(parameters): rfc_dict["sampler"] = parameters["sampler_name"] # display weighted subprompts (liable to change) - subprompts = split_weighted_subprompts(parameters["prompt"], skip_normalize=True) + subprompts = split_weighted_subprompts(parameters["prompt"]) subprompts = [{"prompt": x[0], "weight": x[1]} for x in subprompts] rfc_dict["prompt"] = subprompts diff --git a/cross_attention_loop.py b/cross_attention_loop.py deleted file mode 100644 index ed3e3b0462..0000000000 --- a/cross_attention_loop.py +++ /dev/null @@ -1,186 +0,0 @@ -import random -import traceback - -import numpy as np -import torch - -from diffusers import (LMSDiscreteScheduler) -from PIL import Image -from torch import autocast -from tqdm.auto import tqdm - -import .ldm.models.diffusion.cross_attention - - -@torch.no_grad() -def stablediffusion( - clip, - clip_tokenizer, - device, - vae, - unet, - prompt='', - prompt_edit=None, - prompt_edit_token_weights=None, - prompt_edit_tokens_start=0.0, - prompt_edit_tokens_end=1.0, - prompt_edit_spatial_start=0.0, - prompt_edit_spatial_end=1.0, - guidance_scale=7.5, - steps=50, - seed=None, - width=512, - height=512, - init_image=None, - init_image_strength=0.5, - ): - if prompt_edit_token_weights is None: - prompt_edit_token_weights = [] - # Change size to multiple of 64 to prevent size mismatches inside model - width = width - width % 64 - height = height - height % 64 - - # If seed is None, randomly select seed from 0 to 2^32-1 - if seed is None: seed = random.randrange(2**32 - 1) - generator = torch.manual_seed(seed) - - # Set inference timesteps to scheduler - scheduler = LMSDiscreteScheduler(beta_start=0.00085, - beta_end=0.012, - beta_schedule='scaled_linear', - num_train_timesteps=1000, - ) - scheduler.set_timesteps(steps) - - # Preprocess image if it exists (img2img) - if init_image is not None: - # Resize and transpose for numpy b h w c -> torch b c h w - init_image = init_image.resize((width, height), resample=Image.Resampling.LANCZOS) - init_image = np.array(init_image).astype(np.float32) / 255.0 * 2.0 - 1.0 - init_image = torch.from_numpy(init_image[np.newaxis, ...].transpose(0, 3, 1, 2)) - - # If there is alpha channel, composite alpha for white, as the diffusion - # model does not support alpha channel - if init_image.shape[1] > 3: - init_image = init_image[:, :3] * init_image[:, 3:] + (1 - init_image[:, 3:]) - - # Move image to GPU - init_image = init_image.to(device) - - # Encode image - with autocast(device): - init_latent = (vae.encode(init_image) - .latent_dist - .sample(generator=generator) - * 0.18215) - - t_start = steps - int(steps * init_image_strength) - - else: - init_latent = torch.zeros((1, 
unet.in_channels, height // 8, width // 8), - device=device) - t_start = 0 - - # Generate random normal noise - noise = torch.randn(init_latent.shape, generator=generator, device=device) - latent = scheduler.add_noise(init_latent, - noise, - torch.tensor([scheduler.timesteps[t_start]], device=device) - ).to(device) - - # Process clip - with autocast(device): - tokens_uncond = clip_tokenizer('', padding='max_length', - max_length=clip_tokenizer.model_max_length, - truncation=True, return_tensors='pt', - return_overflowing_tokens=True - ) - embedding_uncond = clip(tokens_uncond.input_ids.to(device)).last_hidden_state - - tokens_cond = clip_tokenizer(prompt, padding='max_length', - max_length=clip_tokenizer.model_max_length, - truncation=True, return_tensors='pt', - return_overflowing_tokens=True - ) - embedding_cond = clip(tokens_cond.input_ids.to(device)).last_hidden_state - - # Process prompt editing - if prompt_edit is not None: - tokens_cond_edit = clip_tokenizer(prompt_edit, padding='max_length', - max_length=clip_tokenizer.model_max_length, - truncation=True, return_tensors='pt', - return_overflowing_tokens=True - ) - embedding_cond_edit = clip(tokens_cond_edit.input_ids.to(device)).last_hidden_state - - c_a_c.init_attention_edit(tokens_cond, tokens_cond_edit) - - c_a_c.init_attention_func() - c_a_c.init_attention_weights(prompt_edit_token_weights) - - timesteps = scheduler.timesteps[t_start:] - - for idx, timestep in tqdm(enumerate(timesteps), total=len(timesteps)): - t_index = t_start + idx - - latent_model_input = latent - latent_model_input = scheduler.scale_model_input(latent_model_input, timestep) - - # Predict the unconditional noise residual - noise_pred_uncond = unet(latent_model_input, - timestep, - encoder_hidden_states=embedding_uncond - ).sample - - # Prepare the Cross-Attention layers - if prompt_edit is not None: - c_a_c.save_last_tokens_attention() - c_a_c.save_last_self_attention() - else: - # Use weights on non-edited prompt when edit is None - c_a_c.use_last_tokens_attention_weights() - - # Predict the conditional noise residual and save the - # cross-attention layer activations - noise_pred_cond = unet(latent_model_input, - timestep, - encoder_hidden_states=embedding_cond - ).sample - - # Edit the Cross-Attention layer activations - if prompt_edit is not None: - t_scale = timestep / scheduler.num_train_timesteps - if (t_scale >= prompt_edit_tokens_start - and t_scale <= prompt_edit_tokens_end): - c_a_c.use_last_tokens_attention() - if (t_scale >= prompt_edit_spatial_start - and t_scale <= prompt_edit_spatial_end): - c_a_c.use_last_self_attention() - - # Use weights on edited prompt - c_a_c.use_last_tokens_attention_weights() - - # Predict the edited conditional noise residual using the - # cross-attention masks - noise_pred_cond = unet(latent_model_input, - timestep, - encoder_hidden_states=embedding_cond_edit - ).sample - - # Perform guidance - noise_pred = (noise_pred_uncond + guidance_scale - * (noise_pred_cond - noise_pred_uncond)) - - latent = scheduler.step(noise_pred, - t_index, - latent - ).prev_sample - - # scale and decode the image latents with vae - latent = latent / 0.18215 - image = vae.decode(latent.to(vae.dtype)).sample - - image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() - image = (image[0] * 255).round().astype('uint8') - return Image.fromarray(image) diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index cfe3ff99bc..2f5e6e61d0 100644 --- a/ldm/invoke/generator/img2img.py +++ 
b/ldm/invoke/generator/img2img.py @@ -51,9 +51,8 @@ class Img2Img(Generator): img_callback = step_callback, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, - init_latent = self.init_latent, + init_latent = self.init_latent, # changes how noising is performed in ksampler extra_conditioning_info = extra_conditioning_info - # changes how noising is performed in ksampler ) return self.sample_to_image(samples) diff --git a/ldm/invoke/txt2mask.py b/ldm/invoke/txt2mask.py index 01d93546e3..bc8251abde 100644 --- a/ldm/invoke/txt2mask.py +++ b/ldm/invoke/txt2mask.py @@ -29,9 +29,9 @@ work fine. import torch import numpy as np -from models.clipseg import CLIPDensePredT +from clipseg_models.clipseg import CLIPDensePredT from einops import rearrange, repeat -from PIL import Image +from PIL import Image, ImageOps from torchvision import transforms CLIP_VERSION = 'ViT-B/16' @@ -50,9 +50,14 @@ class SegmentedGrayscale(object): discrete_heatmap = self.heatmap.lt(threshold).int() return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L')) - def to_transparent(self)->Image: + def to_transparent(self,invert:bool=False)->Image: transparent_image = self.image.copy() - transparent_image.putalpha(self.to_grayscale()) + gs = self.to_grayscale() + # The following line looks like a bug, but isn't. + # For img2img, we want the selected regions to be transparent, + # but to_grayscale() returns the opposite. + gs = ImageOps.invert(gs) if not invert else gs + transparent_image.putalpha(gs) return transparent_image # unscales and uncrops the 352x352 heatmap so that it matches the image again @@ -79,7 +84,7 @@ class Txt2Mask(object): self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False) @torch.no_grad() - def segment(self, image:Image, prompt:str) -> SegmentedGrayscale: + def segment(self, image, prompt:str) -> SegmentedGrayscale: ''' Given a prompt string such as "a bagel", tries to identify the object in the provided image and returns a SegmentedGrayscale object in which the brighter @@ -94,6 +99,10 @@ class Txt2Mask(object): transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64... 
]) + if type(image) is str: + image = Image.open(image).convert('RGB') + + image = ImageOps.exif_transpose(image) img = self._scale_and_crop(image) img = transform(img).unsqueeze(0) diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index 98219fb62e..71944a9b7e 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -1,5 +1,4 @@ """SAMPLING ONLY.""" -from typing import Union import torch from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent @@ -29,7 +28,7 @@ class DDIMSampler(Sampler): def p_sample( self, x, - c: Union[torch.Tensor, list], + c, t, index, repeat_noise=False, diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 73218d36f8..739710d006 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -8,7 +8,7 @@ import numpy as np from einops import rearrange from ldm.util import instantiate_from_config -#from ldm.modules.attention import LinearAttention +from ldm.modules.attention import LinearAttention import psutil @@ -151,10 +151,10 @@ class ResnetBlock(nn.Module): return x + h -#class LinAttnBlock(LinearAttention): -# """to match AttnBlock usage""" -# def __init__(self, in_channels): -# super().__init__(dim=in_channels, heads=1, dim_head=in_channels) +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) class AttnBlock(nn.Module): From ee7d4d712a9ab52a08dc5d1ce62c6b5a69f379e5 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sat, 22 Oct 2022 11:27:56 +0200 Subject: [PATCH 36/54] parsing CrossAttentionControlSubstitute options works --- ldm/invoke/prompt_parser.py | 42 +++++++++++++++++++++++++++++-------- tests/test_prompt_parser.py | 17 ++++++++++++++- 2 files changed, 49 insertions(+), 10 deletions(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 39138e5364..f5b369bc48 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -118,16 +118,27 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment): Fragment('sitting on a car') ]) """ - def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list]): + def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list], options: dict=None): self.original = original self.edited = edited + default_options = { + 's_start': 0.0, + 's_end': 1.0, + 't_start': 0.0, + 't_end': 1.0 + } + merged_options = default_options + if options is not None: + merged_options.update(options) + self.options = merged_options def __repr__(self): - return f"CrossAttentionControlSubstitute:({self.original}->{self.edited})" + return f"CrossAttentionControlSubstitute:({self.original}->{self.edited} ({self.options})" def __eq__(self, other): return type(other) is CrossAttentionControlSubstitute \ and other.original == self.original \ - and other.edited == self.edited + and other.edited == self.edited \ + and other.options == self.options class CrossAttentionControlAppend(CrossAttentionControlledFragment): @@ -239,7 +250,7 @@ class PromptParser(): if type(x) is CrossAttentionControlSubstitute: original_fused = fuse_fragments(x.original) edited_fused = fuse_fragments(x.edited) - result.append(CrossAttentionControlSubstitute(original_fused, edited_fused)) + result.append(CrossAttentionControlSubstitute(original_fused, edited_fused, options=x.options)) else: last_weight = 
result[-1].weight \ if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \ @@ -269,7 +280,7 @@ class PromptParser(): elif type(node) is CrossAttentionControlSubstitute: original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ') edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ') - results += [CrossAttentionControlSubstitute(original, edited)] + results += [CrossAttentionControlSubstitute(original, edited, options=node.options)] elif type(node) is Blend: flattened_subprompts = [] #print(" flattening blend with prompts", node.prompts, "weights", node.weights) @@ -306,6 +317,7 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) lparen = pp.Literal("(").suppress() rparen = pp.Literal(")").suppress() quotes = pp.Literal('"').suppress() + comma = pp.Literal(",").suppress() # accepts int or float notation, always maps to float number = pp.pyparsing_common.real | \ @@ -443,7 +455,18 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) parenthesized_fragment.set_debug(debug_cross_attention_control), pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_text_fragment) + pp.FollowedBy(".swap") ]) - edited_fragment = parenthesized_fragment + # support keyword=number arguments + cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end")]) + cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number) + edited_fragment = pp.MatchFirst([ + lparen + + (quoted_fragment | + pp.Group(pp.Word(pp.printables, exclude_chars=string.whitespace + ',').set_parse_action(make_text_fragment)) + ) + + pp.Dict(pp.OneOrMore(comma + cross_attention_option)) + + rparen, + parenthesized_fragment + ]) cross_attention_substitute << original_fragment + pp.Literal(".swap").suppress() + edited_fragment original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control) @@ -451,9 +474,10 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control) def make_cross_attention_substitute(x): - #print("making cacs for", x) - cacs = CrossAttentionControlSubstitute(x[0], x[1]) - #print("made", cacs) + print("making cacs for", x[0], "->", x[1], "with options", x.as_dict()) + #if len(x>2): + cacs = CrossAttentionControlSubstitute(x[0], x[1], options=x.as_dict()) + print("made", cacs) return cacs cross_attention_substitute.set_parse_action(make_cross_attention_substitute) diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 84971fcc52..02644012d8 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -247,7 +247,22 @@ class PromptParserTestCase(unittest.TestCase): Fragment(',', 1), Fragment('fire', 2.0)])]) self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)')) - + def test_cross_attention_control_options(self): + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start':0.1}), + Fragment('eating a hotdog', 1)])]), + parse_prompt("a \"cat\".swap(dog, s_start=0.1) eating a hotdog")) + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 
1)], [Fragment('dog', 1)], options={'t_start':0.1}), + Fragment('eating a hotdog', 1)])]), + parse_prompt("a \"cat\".swap(dog, t_start=0.1) eating a hotdog")) + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start': 20.0, 't_start':0.1}), + Fragment('eating a hotdog', 1)])]), + parse_prompt("a \"cat\".swap(dog, t_start=0.1, s_start=20) eating a hotdog")) def test_escaping(self): From 8273c04575dfb62653888e42125df447bbec93e6 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sat, 22 Oct 2022 12:15:34 +0200 Subject: [PATCH 37/54] wip implementing options in diffuse step --- ldm/invoke/conditioning.py | 38 ++++++++++++++++++++++++++++++--- ldm/modules/encoders/modules.py | 15 +++++++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index 6a44986f8d..924ea39c77 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -16,7 +16,8 @@ from typing import Union import torch from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \ - CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend + CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment +from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder @@ -65,27 +66,54 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n if wants_cross_attention_control: original_prompt = FlattenedPrompt() edited_prompt = FlattenedPrompt() + # for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed + original_token_count = 0 + edited_token_count = 0 + edit_opcodes = [] + edit_options = [] for fragment in flattened_prompt.children: if type(fragment) is CrossAttentionControlSubstitute: original_prompt.append(fragment.original) edited_prompt.append(fragment.edited) + + to_replace_token_count = get_tokens_length(model, fragment.original) + replacement_token_count = get_tokens_length(model, fragment.edited) + edit_opcodes.append(('replace', + original_token_count, original_token_count + to_replace_token_count, + edited_token_count, edited_token_count + replacement_token_count + )) + original_token_count += to_replace_token_count + edited_token_count += replacement_token_count + edit_options.append(fragment.options) #elif type(fragment) is CrossAttentionControlAppend: # edited_prompt.append(fragment.fragment) else: # regular fragment original_prompt.append(fragment) edited_prompt.append(fragment) + + count = get_tokens_length(model, [fragment]) + edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count)) + edit_options.append(None) + original_token_count += count + edited_token_count += count original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt) edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt) conditioning = original_embeddings edited_conditioning = edited_embeddings - edit_opcodes = build_token_edit_opcodes(original_tokens, edited_tokens) + print('got edit_opcodes', edit_opcodes, 'options', edit_options) else: conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt) + unconditioning, _ = 
build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt) - return (unconditioning, conditioning, edited_conditioning, edit_opcodes) + return ( + unconditioning, conditioning, edited_conditioning, edit_opcodes + #InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=edited_conditioning, + # edit_opcodes=edit_opcodes, + # edit_options=edit_options) + ) def build_token_edit_opcodes(original_tokens, edited_tokens): @@ -102,6 +130,10 @@ def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: Fl embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights]) return embeddings, tokens +def get_tokens_length(model, fragments: list[Fragment]): + fragment_texts = [x.text for x in fragments] + tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False) + return sum([len(x) for x in tokens]) def split_weighted_subprompts(text, skip_normalize=False)->list: diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 2883f24d1a..8917a27a40 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -557,6 +557,21 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): else: return batch_z + def get_tokens(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]: + tokens = self.tokenizer( + fragments, + truncation=True, + max_length=self.max_length, + return_overflowing_tokens=False, + padding='do_not_pad', + return_tensors=None, # just give me a list of ints + )['input_ids'] + if include_start_and_end_markers: + return tokens + else: + return [x[1:-1] for x in tokens] + + @classmethod def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor: per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device) From 7d677a63b834dcff98d7db25ef6a9c0e293cd953 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sun, 23 Oct 2022 14:58:25 +0200 Subject: [PATCH 38/54] cross attention control options --- ldm/generate.py | 8 +- ldm/invoke/conditioning.py | 22 +- ldm/invoke/generator/img2img.py | 3 +- ldm/invoke/generator/inpaint.py | 2 +- ldm/invoke/generator/txt2img.py | 3 +- ldm/invoke/generator/txt2img2img.py | 3 +- .../diffusion/cross_attention_control.py | 236 +++++++++++++ ldm/models/diffusion/ddim.py | 9 +- ldm/models/diffusion/ksampler.py | 6 +- ldm/models/diffusion/plms.py | 11 +- ldm/models/diffusion/sampler.py | 1 + .../diffusion/shared_invokeai_diffusion.py | 313 +++--------------- 12 files changed, 318 insertions(+), 299 deletions(-) create mode 100644 ldm/models/diffusion/cross_attention_control.py diff --git a/ldm/generate.py b/ldm/generate.py index f83a732816..39f0b06759 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -400,7 +400,7 @@ class Generate: mask_image = None try: - uc, c, ec, ec_index_map = get_uc_and_c_and_ec( + uc, c, extra_conditioning_info = get_uc_and_c_and_ec( prompt, model =self.model, skip_normalize=skip_normalize, log_tokens =self.log_tokenization @@ -438,7 +438,7 @@ class Generate: sampler=self.sampler, steps=steps, cfg_scale=cfg_scale, - conditioning=(uc, c, ec, ec_index_map), + conditioning=(uc, c, extra_conditioning_info), ddim_eta=ddim_eta, image_callback=image_callback, # called after the final image is generated step_callback=step_callback, # called after each intermediate image is generated @@ -541,8 +541,8 @@ class Generate: image 
= Image.open(image_path) # used by multiple postfixers - # todo: cross-attention - uc, c, _, _ = get_uc_and_c_and_ec( + # todo: cross-attention control + uc, c, _ = get_uc_and_c_and_ec( prompt, model =self.model, skip_normalize=opt.skip_normalize, log_tokens =opt.log_tokenization diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index 924ea39c77..52d40312ac 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -17,6 +17,7 @@ import torch from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \ CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment +from ..models.diffusion.cross_attention_control import CrossAttentionControl from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder @@ -46,8 +47,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n print("parsed prompt to", parsed_prompt) conditioning = None - edited_conditioning = None - edit_opcodes = None + cac_args:CrossAttentionControl.Arguments = None if type(parsed_prompt) is Blend: blend: Blend = parsed_prompt @@ -98,21 +98,31 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n original_token_count += count edited_token_count += count original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt) + # naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of + # subsequent tokens when there is >1 edit and earlier edits change the total token count. + # eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the + # 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra + # token 'smiling' in the inactive 'cat' edit. 
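As an aside, a minimal standalone sketch of how the loop above accumulates 'equal'/'replace' opcodes from per-fragment token counts; it is illustrative only (not part of the patch), and the token counts in the example are hypothetical:

# Sketch only: mirrors the opcode-accumulation loop in get_uc_and_c_and_ec.
def sketch_opcodes(fragments):
    # fragments: (kind, original_token_count, edited_token_count) per prompt fragment;
    # kind is 'equal' for unchanged text, 'replace' for a .swap() substitution.
    opcodes = []
    orig_pos, edit_pos = 0, 0
    for kind, orig_len, edit_len in fragments:
        opcodes.append((kind, orig_pos, orig_pos + orig_len,
                        edit_pos, edit_pos + edit_len))
        orig_pos += orig_len
        edit_pos += edit_len
    return opcodes

# hypothetical counts for: a cat.swap(smiling dog) eating a hotdog
print(sketch_opcodes([('equal', 1, 1), ('replace', 1, 2), ('equal', 4, 4)]))
# -> [('equal', 0, 1, 0, 1), ('replace', 1, 2, 1, 3), ('equal', 2, 6, 3, 7)]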
+ # todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt) conditioning = original_embeddings edited_conditioning = edited_embeddings print('got edit_opcodes', edit_opcodes, 'options', edit_options) + cac_args = CrossAttentionControl.Arguments( + edited_conditioning = edited_conditioning, + edit_opcodes = edit_opcodes, + edit_options = edit_options + ) else: conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt) unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt) return ( - unconditioning, conditioning, edited_conditioning, edit_opcodes - #InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=edited_conditioning, - # edit_opcodes=edit_opcodes, - # edit_options=edit_options) + unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo( + cross_attention_control_args=cac_args + ) ) diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index 2f5e6e61d0..4942bcc0c3 100644 --- a/ldm/invoke/generator/img2img.py +++ b/ldm/invoke/generator/img2img.py @@ -33,8 +33,7 @@ class Img2Img(Generator): ) # move to latent space t_enc = int(strength * steps) - uc, c, ec, edit_opcodes = conditioning - extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes) + uc, c, extra_conditioning_info = conditioning def make_image(x_T): # encode (scaled latent) diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py index 8f01b4ad2d..25bbc7e017 100644 --- a/ldm/invoke/generator/inpaint.py +++ b/ldm/invoke/generator/inpaint.py @@ -46,7 +46,7 @@ class Inpaint(Img2Img): t_enc = int(strength * steps) # todo: support cross-attention control - uc, c, _, _ = conditioning + uc, c, _ = conditioning print(f">> target t_enc is {t_enc} steps") diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index 7e739860c3..696cc06f78 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -21,8 +21,7 @@ class Txt2Img(Generator): kwargs are 'width' and 'height' """ self.perlin = perlin - uc, c, ec, edit_opcodes = conditioning - extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes) + uc, c, extra_conditioning_info = conditioning @torch.no_grad() def make_image(x_T): diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index 2d67a44346..5808f7bdb2 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -23,8 +23,7 @@ class Txt2Img2Img(Generator): Return value depends on the seed at the time you call it kwargs are 'width' and 'height' """ - uc, c, ec, edit_opcodes = conditioning - extra_conditioning_info = InvokeAIDiffuserComponent.ExtraConditioningInfo(edited_conditioning=ec, edit_opcodes=edit_opcodes) + uc, c, extra_conditioing_info = conditioning @torch.no_grad() def make_image(x_T): diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py new file mode 100644 index 0000000000..905803ccfa --- /dev/null +++ b/ldm/models/diffusion/cross_attention_control.py @@ -0,0 +1,236 @@ +from enum import Enum + +import torch + +# adapted from bloc97's CrossAttentionControl colab +# 
https://github.com/bloc97/CrossAttentionControl + +class CrossAttentionControl: + + class Arguments: + def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict): + """ + :param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768] + :param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required) + :param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes. + """ + # todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector + self.edited_conditioning = edited_conditioning + self.edit_opcodes = edit_opcodes + + if edited_conditioning is not None: + assert len(edit_opcodes) == len(edit_options), \ + "there must be 1 edit_options dict for each edit_opcodes tuple" + non_none_edit_options = [x for x in edit_options if x is not None] + assert len(non_none_edit_options)>0, "missing edit_options" + if len(non_none_edit_options)>1: + print('warning: cross-attention control options are not working properly for >1 edit') + self.edit_options = non_none_edit_options[0] + + class Context: + def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int): + self.arguments = arguments + self.step_count = step_count + + @classmethod + def remove_cross_attention_control(cls, model): + cls.remove_attention_function(model) + + @classmethod + def setup_cross_attention_control(cls, model, + cross_attention_control_args: Arguments + ): + """ + Inject attention parameters and functions into the passed in model to enable cross attention editing. + + :param model: The unet model to inject into. + :param cross_attention_control_args: Arugments passeed to the CrossAttentionControl implementations + :return: None + """ + + # adapted from init_attention_edit + device = cross_attention_control_args.edited_conditioning.device + + # urgh. should this be hardcoded? + max_length = 77 + # mask=1 means use base prompt attention, mask=0 means use edited prompt attention + mask = torch.zeros(max_length) + indices_target = torch.arange(max_length, dtype=torch.long) + indices = torch.zeros(max_length, dtype=torch.long) + for name, a0, a1, b0, b1 in cross_attention_control_args.edit_opcodes: + if b0 < max_length: + if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): + # these tokens have not been edited + indices[b0:b1] = indices_target[a0:a1] + mask[b0:b1] = 1 + + for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF): + m.last_attn_slice_mask = None + m.last_attn_slice_indices = None + + for m in cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS): + m.last_attn_slice_mask = mask.to(device) + m.last_attn_slice_indices = indices.to(device) + + cls.inject_attention_function(model) + + + class CrossAttentionType(Enum): + SELF = 1 + TOKENS = 2 + + @classmethod + def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', step_index:int=None)\ + -> list['CrossAttentionControl.CrossAttentionType']: + """ + Should cross-attention control be applied on the given step? + :param step_index: The step index (counts upwards from 0), or None if unknown. + :return: A list of attention types that cross-attention control should be performed for on the given step. May be []. 
+ """ + if step_index is None: + return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS] + + opts = context.arguments.edit_options + # percent_through will never reach 1.0 (but this is intended) + percent_through = float(step_index)/float(context.step_count) + to_control = [] + if opts['s_start'] <= percent_through and percent_through < opts['s_end']: + to_control.append(cls.CrossAttentionType.SELF) + if opts['t_start'] <= percent_through and percent_through < opts['t_end']: + to_control.append(cls.CrossAttentionType.TOKENS) + return to_control + + + @classmethod + def get_attention_modules(cls, model, which: CrossAttentionType): + which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2" + return [module for name, module in model.named_modules() if + type(module).__name__ == "CrossAttention" and which_attn in name] + + @classmethod + def clear_requests(cls, model): + self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF) + tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS) + for m in self_attention_modules+tokens_attention_modules: + m.save_last_attn_slice = False + m.use_last_attn_slice = False + + @classmethod + def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType): + modules = cls.get_attention_modules(model, cross_attention_type) + for m in modules: + # clear out the saved slice in case the outermost dim changes + m.last_attn_slice = None + m.save_last_attn_slice = True + + @classmethod + def request_apply_saved_attention_maps(cls, model, cross_attention_type: CrossAttentionType): + modules = cls.get_attention_modules(model, cross_attention_type) + for m in modules: + m.use_last_attn_slice = True + + + + @classmethod + def inject_attention_function(cls, unet): + # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 + + def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size): + + #print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim) + + attn_slice = suggested_attention_slice + if dim is not None: + start = offset + end = start+slice_size + #print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}") + #else: + # print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}") + + + if self.use_last_attn_slice: + this_attn_slice = attn_slice + if self.last_attn_slice_mask is not None: + # indices and mask operate on dim=2, no need to slice + base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices) + base_attn_slice_mask = self.last_attn_slice_mask + if dim is None: + base_attn_slice = base_attn_slice_full + #print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) + elif dim == 0: + base_attn_slice = base_attn_slice_full[start:end] + #print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) + elif dim == 1: + base_attn_slice = base_attn_slice_full[:, start:end] + #print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) + + attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \ + 
base_attn_slice * base_attn_slice_mask + else: + if dim is None: + attn_slice = self.last_attn_slice + #print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) + elif dim == 0: + attn_slice = self.last_attn_slice[start:end] + #print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) + elif dim == 1: + attn_slice = self.last_attn_slice[:, start:end] + #print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) + + if self.save_last_attn_slice: + if dim is None: + self.last_attn_slice = attn_slice + elif dim == 0: + # dynamically grow last_attn_slice if needed + if self.last_attn_slice is None: + self.last_attn_slice = attn_slice + #print("no last_attn_slice: shape now", self.last_attn_slice.shape) + elif self.last_attn_slice.shape[0] == start: + self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0) + assert(self.last_attn_slice.shape[0] == end) + #print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape) + else: + # no need to grow + self.last_attn_slice[start:end] = attn_slice + #print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape) + + elif dim == 1: + # dynamically grow last_attn_slice if needed + if self.last_attn_slice is None: + self.last_attn_slice = attn_slice + elif self.last_attn_slice.shape[1] == start: + self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1) + assert(self.last_attn_slice.shape[1] == end) + else: + # no need to grow + self.last_attn_slice[:, start:end] = attn_slice + + if self.use_last_attn_weights and self.last_attn_slice_weights is not None: + if dim is None: + weights = self.last_attn_slice_weights + elif dim == 0: + weights = self.last_attn_slice_weights[start:end] + elif dim == 1: + weights = self.last_attn_slice_weights[:, start:end] + attn_slice = attn_slice * weights + + return attn_slice + + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention": + module.last_attn_slice = None + module.last_attn_slice_indices = None + module.last_attn_slice_mask = None + module.use_last_attn_weights = False + module.use_last_attn_slice = False + module.save_last_attn_slice = False + module.set_attention_slice_wrangler(attention_slice_wrangler) + + @classmethod + def remove_attention_function(cls, unet): + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention": + module.set_attention_slice_wrangler(None) + diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index 71944a9b7e..5b5dfaf4af 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -18,7 +18,7 @@ class DDIMSampler(Sampler): extra_conditioning_info = kwargs.get('extra_conditioning_info', None) if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info) + self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc) else: self.invokeai_diffuser.remove_cross_attention_control() @@ -40,6 +40,7 @@ class DDIMSampler(Sampler): corrector_kwargs=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + step_count:int=1000, # total number of steps **kwargs, ): b, *_, device = 
*x.shape, x.device @@ -51,7 +52,11 @@ class DDIMSampler(Sampler): # damian0815 would like to know when/if this code path is used e_t = self.model.apply_model(x, t, c) else: - e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale) + step_index = step_count-(index+1) + e_t = self.invokeai_diffuser.do_diffusion_step(x, t, + unconditional_conditioning, c, + unconditional_guidance_scale, + step_index=step_index) if score_corrector is not None: assert self.model.parameterization == 'eps' diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 44e418acb1..7bf48c62e8 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -37,14 +37,14 @@ class CFGDenoiser(nn.Module): extra_conditioning_info = kwargs.get('extra_conditioning_info', None) if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info) + self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc) else: self.invokeai_diffuser.remove_cross_attention_control() - def forward(self, x, sigma, uncond, cond, cond_scale): + def forward(self, x, sigma, uncond, cond, cond_scale, step_index): - next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale) + next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale, step_index) # apply threshold if self.warmup < self.warmup_max: diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index f58e2c3220..5b4674f28d 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -23,7 +23,7 @@ class PLMSSampler(Sampler): extra_conditioning_info = kwargs.get('extra_conditioning_info', None) if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info) + self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc) else: self.invokeai_diffuser.remove_cross_attention_control() @@ -47,6 +47,7 @@ class PLMSSampler(Sampler): unconditional_conditioning=None, old_eps=[], t_next=None, + step_count:int=1000, # total number of steps **kwargs, ): b, *_, device = *x.shape, x.device @@ -59,7 +60,13 @@ class PLMSSampler(Sampler): # damian0815 would like to know when/if this code path is used e_t = self.model.apply_model(x, t, c) else: - e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale) + # step_index is expected to count up while index counts down + step_index = step_count-(index+1) + # note that step_index == 0 is evaluated twice with different x + e_t = self.invokeai_diffuser.do_diffusion_step(x, t, + unconditional_conditioning, c, + unconditional_guidance_scale, + step_index=step_index) if score_corrector is not None: assert self.model.parameterization == 'eps' diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index e33d57fe31..8099997bb3 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -278,6 +278,7 @@ class Sampler(object): unconditional_conditioning=unconditional_conditioning, old_eps=old_eps, t_next=ts_next, + step_count=steps ) img, pred_x0, e_t = outs diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 
290925fc8c..507feacaa9 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -1,9 +1,11 @@ from enum import Enum from math import ceil -from typing import Callable +from typing import Callable, Optional import torch +from ldm.models.diffusion.cross_attention_control import CrossAttentionControl + class InvokeAIDiffuserComponent: ''' @@ -16,19 +18,16 @@ class InvokeAIDiffuserComponent: class ExtraConditioningInfo: - def __init__(self, edited_conditioning: torch.Tensor = None, edit_opcodes: list[tuple] = None): - """ - :param edited_conditioning: if doing cross-attention control, the edited conditioning (1 x 77 x 768) - :param edit_opcodes: if doing cross-attention control, opcodes from a SequenceMatcher describing how to map original conditioning tokens to edited conditioning tokens - """ - self.edited_conditioning = edited_conditioning - self.edit_opcodes = edit_opcodes + def __init__(self, cross_attention_control_args: Optional[CrossAttentionControl.Arguments]): + self.cross_attention_control_args = cross_attention_control_args @property def wants_cross_attention_control(self): - return self.edited_conditioning is not None + return self.cross_attention_control_args is not None - def __init__(self, model, model_forward_callback: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]): + def __init__(self, model, model_forward_callback: + Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] + ): """ :param model: the unet model to pass through to cross attention control :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning) @@ -37,44 +36,53 @@ class InvokeAIDiffuserComponent: self.model_forward_callback = model_forward_callback - def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo): + def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int): self.conditioning = conditioning - CrossAttentionControl.setup_cross_attention_control(self.model, conditioning.edited_conditioning, conditioning.edit_opcodes) + self.cross_attention_control_context = CrossAttentionControl.Context( + arguments=self.conditioning.cross_attention_control_args, + step_count=step_count + ) + CrossAttentionControl.setup_cross_attention_control(self.model, + cross_attention_control_args=self.conditioning.cross_attention_control_args + ) + #todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct + #todo: apply edit_options using step_count + def remove_cross_attention_control(self): self.conditioning = None + self.cross_attention_control_context = None CrossAttentionControl.remove_cross_attention_control(self.model) - @property - def edited_conditioning(self): - if self.conditioning is None: - return None - else: - return self.conditioning.edited_conditioning - def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor, unconditioning: torch.Tensor, conditioning: torch.Tensor, - unconditional_guidance_scale: float): + unconditional_guidance_scale: float, + step_index: int=None): """ :param x: Current latents :param sigma: aka t, passed to the internal model to control how much denoising will occur :param unconditioning: [B x 77 x 768] embeddings for unconditioned output :param conditioning: [B x 77 x 768] embeddings for conditioned output :param unconditional_guidance_scale: aka CFG scale, controls how much effect 
the conditioning tensor has - :param model: the unet model to pass through to cross attention control - :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning) - :return: the new latents after applying the model to x using unconditioning and CFG-scaled conditioning. + :param step_index: Counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase. + :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning. """ CrossAttentionControl.clear_requests(self.model) + cross_attention_control_types_to_do = [] - if self.edited_conditioning is None: + if self.cross_attention_control_context is not None: + cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, step_index) + + if len(cross_attention_control_types_to_do)==0: + print('step', step_index, ': not doing cross attention control') # faster batched path x_twice = torch.cat([x]*2) sigma_twice = torch.cat([sigma]*2) both_conditionings = torch.cat([unconditioning, conditioning]) unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2) else: + print('step', step_index, ': doing cross attention control on', cross_attention_control_types_to_do) # slower non-batched path (20% slower on mac MPS) # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x. 
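Before the next hunk, a condensed sketch of the slower, non-batched path it implements, written against the CrossAttentionControl classmethods introduced by this commit; `forward` stands in for self.model_forward_callback, and the sketch is illustrative rather than a drop-in replacement:

from ldm.models.diffusion.cross_attention_control import CrossAttentionControl

def sketch_controlled_step(model, forward, x, sigma, uncond, cond, edited_cond, types_to_do):
    # the unconditioned pass is unaffected by cross-attention control
    unconditioned_next_x = forward(x, sigma, uncond)

    # run the original prompt once, capturing its attention maps per requested type
    for t in types_to_do:                                   # SELF and/or TOKENS
        CrossAttentionControl.request_save_attention_maps(model, t)
    _ = forward(x, sigma, cond)
    CrossAttentionControl.clear_requests(model)

    # run again with the edited conditioning, re-applying the saved maps
    for t in types_to_do:
        CrossAttentionControl.request_apply_saved_attention_maps(model, t)
    conditioned_next_x = forward(x, sigma, edited_cond)
    CrossAttentionControl.clear_requests(model)

    return unconditioned_next_x, conditioned_next_x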
@@ -86,13 +94,16 @@ class InvokeAIDiffuserComponent: unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) # process x using the original prompt, saving the attention maps - CrossAttentionControl.request_save_attention_maps(self.model) + for type in cross_attention_control_types_to_do: + CrossAttentionControl.request_save_attention_maps(self.model, type) _ = self.model_forward_callback(x, sigma, conditioning) CrossAttentionControl.clear_requests(self.model) # process x again, using the saved attention maps to control where self.edited_conditioning will be applied - CrossAttentionControl.request_apply_saved_attention_maps(self.model) - conditioned_next_x = self.model_forward_callback(x, sigma, self.edited_conditioning) + for type in cross_attention_control_types_to_do: + CrossAttentionControl.request_apply_saved_attention_maps(self.model, type) + edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning + conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning) CrossAttentionControl.clear_requests(self.model) @@ -102,7 +113,6 @@ class InvokeAIDiffuserComponent: return combined_next_x - # todo: make this work @classmethod def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale): @@ -153,250 +163,3 @@ class InvokeAIDiffuserComponent: return uncond_latents + deltas_merged * global_guidance_scale - -# adapted from bloc97's CrossAttentionControl colab -# https://github.com/bloc97/CrossAttentionControl - -class CrossAttentionControl: - - - @classmethod - def remove_cross_attention_control(cls, model): - cls.remove_attention_function(model) - - @classmethod - def setup_cross_attention_control(cls, model, - substitute_conditioning: torch.Tensor, - edit_opcodes: list): - """ - Inject attention parameters and functions into the passed in model to enable cross attention editing. - - :param model: The unet model to inject into. - :param substitute_conditioning: The "edited" conditioning vector, [Bx77x768] - :param edit_opcodes: Opcodes from difflib.SequenceMatcher describing how the base - conditionings map to the "edited" conditionings. - :return: - """ - - # adapted from init_attention_edit - device = substitute_conditioning.device - - # urgh. should this be hardcoded? 
- max_length = 77 - # mask=1 means use base prompt attention, mask=0 means use edited prompt attention - mask = torch.zeros(max_length) - indices_target = torch.arange(max_length, dtype=torch.long) - indices = torch.zeros(max_length, dtype=torch.long) - for name, a0, a1, b0, b1 in edit_opcodes: - if b0 < max_length: - if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): - # these tokens have not been edited - indices[b0:b1] = indices_target[a0:a1] - mask[b0:b1] = 1 - - for m in cls.get_attention_modules(model, cls.AttentionType.SELF): - m.last_attn_slice_mask = None - m.last_attn_slice_indices = None - - for m in cls.get_attention_modules(model, cls.AttentionType.TOKENS): - m.last_attn_slice_mask = mask.to(device) - m.last_attn_slice_indices = indices.to(device) - - cls.inject_attention_function(model) - - - class AttentionType(Enum): - SELF = 1 - TOKENS = 2 - - - @classmethod - def get_attention_modules(cls, model, which: AttentionType): - which_attn = "attn1" if which is cls.AttentionType.SELF else "attn2" - return [module for name, module in model.named_modules() if - type(module).__name__ == "CrossAttention" and which_attn in name] - - @classmethod - def clear_requests(cls, model): - self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF) - tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS) - for m in self_attention_modules+tokens_attention_modules: - m.save_last_attn_slice = False - m.use_last_attn_slice = False - - @classmethod - def request_save_attention_maps(cls, model): - self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF) - tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS) - for m in self_attention_modules+tokens_attention_modules: - # clear out the saved slice in case the outermost dim changes - m.last_attn_slice = None - m.save_last_attn_slice = True - - @classmethod - def request_apply_saved_attention_maps(cls, model): - self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF) - tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS) - for m in self_attention_modules+tokens_attention_modules: - m.use_last_attn_slice = True - - - - @classmethod - def inject_attention_function(cls, unet): - # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 - - def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size): - - #print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim) - - attn_slice = suggested_attention_slice - if dim is not None: - start = offset - end = start+slice_size - #print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}") - #else: - # print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}") - - - if self.use_last_attn_slice: - this_attn_slice = attn_slice - if self.last_attn_slice_mask is not None: - # indices and mask operate on dim=2, no need to slice - base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices) - base_attn_slice_mask = self.last_attn_slice_mask - if dim is None: - base_attn_slice = base_attn_slice_full - #print("using whole base slice of shape", 
base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) - elif dim == 0: - base_attn_slice = base_attn_slice_full[start:end] - #print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) - elif dim == 1: - base_attn_slice = base_attn_slice_full[:, start:end] - #print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) - - attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \ - base_attn_slice * base_attn_slice_mask - else: - if dim is None: - attn_slice = self.last_attn_slice - #print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) - elif dim == 0: - attn_slice = self.last_attn_slice[start:end] - #print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) - elif dim == 1: - attn_slice = self.last_attn_slice[:, start:end] - #print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) - - if self.save_last_attn_slice: - if dim is None: - self.last_attn_slice = attn_slice - elif dim == 0: - # dynamically grow last_attn_slice if needed - if self.last_attn_slice is None: - self.last_attn_slice = attn_slice - #print("no last_attn_slice: shape now", self.last_attn_slice.shape) - elif self.last_attn_slice.shape[0] == start: - self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0) - assert(self.last_attn_slice.shape[0] == end) - #print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape) - else: - # no need to grow - self.last_attn_slice[start:end] = attn_slice - #print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape) - - elif dim == 1: - # dynamically grow last_attn_slice if needed - if self.last_attn_slice is None: - self.last_attn_slice = attn_slice - elif self.last_attn_slice.shape[1] == start: - self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1) - assert(self.last_attn_slice.shape[1] == end) - else: - # no need to grow - self.last_attn_slice[:, start:end] = attn_slice - - if self.use_last_attn_weights and self.last_attn_slice_weights is not None: - if dim is None: - weights = self.last_attn_slice_weights - elif dim == 0: - weights = self.last_attn_slice_weights[start:end] - elif dim == 1: - weights = self.last_attn_slice_weights[:, start:end] - attn_slice = attn_slice * weights - - return attn_slice - - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == "CrossAttention": - module.last_attn_slice = None - module.last_attn_slice_indices = None - module.last_attn_slice_mask = None - module.use_last_attn_weights = False - module.use_last_attn_slice = False - module.save_last_attn_slice = False - module.set_attention_slice_wrangler(attention_slice_wrangler) - - @classmethod - def remove_attention_function(cls, unet): - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == "CrossAttention": - module.set_attention_slice_wrangler(None) - - -# original code below - -# Functions supporting Cross-Attention Control -# Copied from https://github.com/bloc97/CrossAttentionControl - -from difflib import SequenceMatcher - -import torch - - -def prompt_token(prompt, index, clip_tokenizer): - tokens = clip_tokenizer(prompt, - padding='max_length', - 
max_length=clip_tokenizer.model_max_length, - truncation=True, - return_tensors='pt', - return_overflowing_tokens=True - ).input_ids[0] - return clip_tokenizer.decode(tokens[index:index + 1]) - - -def use_last_tokens_attention(unet, use=True): - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == 'CrossAttention' and 'attn2' in name: - module.use_last_attn_slice = use - - -def use_last_tokens_attention_weights(unet, use=True): - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == 'CrossAttention' and 'attn2' in name: - module.use_last_attn_weights = use - - -def use_last_self_attention(unet, use=True): - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == 'CrossAttention' and 'attn1' in name: - module.use_last_attn_slice = use - - -def save_last_tokens_attention(unet, save=True): - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == 'CrossAttention' and 'attn2' in name: - module.save_last_attn_slice = save - - -def save_last_self_attention(unet, save=True): - for name, module in unet.named_modules(): - module_name = type(module).__name__ - if module_name == 'CrossAttention' and 'attn1' in name: - module.save_last_attn_slice = save From 04d93f044514096274911fa2878a5b0467823ac3 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sun, 23 Oct 2022 16:26:50 +0200 Subject: [PATCH 39/54] for k* samplers, estimate step_index from sigma --- ldm/models/diffusion/cross_attention_control.py | 8 +++----- ldm/models/diffusion/ksampler.py | 4 ++-- ldm/models/diffusion/plms.py | 3 +-- ldm/models/diffusion/shared_invokeai_diffusion.py | 15 +++++++++++++-- 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 905803ccfa..6e873f1c6d 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -80,19 +80,17 @@ class CrossAttentionControl: TOKENS = 2 @classmethod - def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', step_index:int=None)\ + def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', percent_through:float=None)\ -> list['CrossAttentionControl.CrossAttentionType']: """ Should cross-attention control be applied on the given step? - :param step_index: The step index (counts upwards from 0), or None if unknown. + :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0. :return: A list of attention types that cross-attention control should be performed for on the given step. May be []. 
""" - if step_index is None: + if percent_through is None: return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS] opts = context.arguments.edit_options - # percent_through will never reach 1.0 (but this is intended) - percent_through = float(step_index)/float(context.step_count) to_control = [] if opts['s_start'] <= percent_through and percent_through < opts['s_end']: to_control.append(cls.CrossAttentionType.SELF) diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index 7bf48c62e8..2f5bf53850 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -42,9 +42,9 @@ class CFGDenoiser(nn.Module): self.invokeai_diffuser.remove_cross_attention_control() - def forward(self, x, sigma, uncond, cond, cond_scale, step_index): + def forward(self, x, sigma, uncond, cond, cond_scale): - next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale, step_index) + next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale) # apply threshold if self.warmup < self.warmup_max: diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 5b4674f28d..40c2631bcd 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -60,9 +60,8 @@ class PLMSSampler(Sampler): # damian0815 would like to know when/if this code path is used e_t = self.model.apply_model(x, t, c) else: - # step_index is expected to count up while index counts down + # step_index counts in the opposite direction to index step_index = step_count-(index+1) - # note that step_index == 0 is evaluated twice with different x e_t = self.invokeai_diffuser.do_diffusion_step(x, t, unconditional_conditioning, c, unconditional_guidance_scale, diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 507feacaa9..56550094a8 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -57,7 +57,8 @@ class InvokeAIDiffuserComponent: def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor, unconditioning: torch.Tensor, conditioning: torch.Tensor, unconditional_guidance_scale: float, - step_index: int=None): + step_index: int=None + ): """ :param x: Current latents :param sigma: aka t, passed to the internal model to control how much denoising will occur @@ -72,7 +73,17 @@ class InvokeAIDiffuserComponent: cross_attention_control_types_to_do = [] if self.cross_attention_control_context is not None: - cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, step_index) + if step_index is not None: + # percent_through will never reach 1.0 (but this is intended) + percent_through = float(step_index) / float(self.cross_attention_control_context.step_count) + else: + # find the current sigma in the sigma sequence + sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1] + # flip because sigmas[0] is for the fully denoised image + # percent_through must be <1 + percent_through = 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0]) + print('estimated percent_through', percent_through, 'from sigma', sigma) + cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through) if len(cross_attention_control_types_to_do)==0: print('step', step_index, ': not doing cross attention 
control') From 8f35819ddf86543916d48f3d68a4dab2fb92d23b Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sun, 23 Oct 2022 19:38:31 +0200 Subject: [PATCH 40/54] add shape_freedom arg to .swap() --- ldm/invoke/prompt_parser.py | 16 +++++++++++++++- .../diffusion/shared_invokeai_diffusion.py | 5 +++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index f5b369bc48..e8d43b8afd 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -1,3 +1,4 @@ +import math import string from typing import Union @@ -121,15 +122,28 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment): def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list], options: dict=None): self.original = original self.edited = edited + default_options = { 's_start': 0.0, - 's_end': 1.0, + 's_end': 0.3, # gives best results 't_start': 0.0, 't_end': 1.0 } merged_options = default_options if options is not None: + shape_freedom = options.pop('shape_freedom', None) + if shape_freedom is not None: + # high shape freedom = SD can do what it wants with the shape of the object + # high shape freedom => s_end = 0 + # low shape freedom => s_end = 1 + # shape freedom is in a "linear" space, while noticeable changes to s_end are typically closer around 0, + # and there is very little perceptible difference as s_end increases above 0.5 + # so for shape_freedom = 0.5 we probably want s_end to be 0.2 + # -> cube root and subtract from 1.0 + merged_options.s_end = 1.0 - math.cbrt(shape_freedom) + print('converted shape_freedom argument to', merged_options) merged_options.update(options) + self.options = merged_options def __repr__(self): diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 56550094a8..85b594b8f7 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -78,6 +78,7 @@ class InvokeAIDiffuserComponent: percent_through = float(step_index) / float(self.cross_attention_control_context.step_count) else: # find the current sigma in the sigma sequence + # todo: this doesn't work with k_dpm_2 because the sigma used jumps around in the sequence sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1] # flip because sigmas[0] is for the fully denoised image # percent_through must be <1 @@ -86,14 +87,14 @@ class InvokeAIDiffuserComponent: cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through) if len(cross_attention_control_types_to_do)==0: - print('step', step_index, ': not doing cross attention control') + #print('step', step_index, ': not doing cross attention control') # faster batched path x_twice = torch.cat([x]*2) sigma_twice = torch.cat([sigma]*2) both_conditionings = torch.cat([unconditioning, conditioning]) unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2) else: - print('step', step_index, ': doing cross attention control on', cross_attention_control_types_to_do) + #print('step', step_index, ': doing cross attention control on', cross_attention_control_types_to_do) # slower non-batched path (20% slower on mac MPS) # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of # unconditioned_next_x causes attention maps to *also* be saved 
for the unconditioned_next_x. From 9210bf7d3af0f6d7657cb00a476cad0ef1cb9422 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sun, 23 Oct 2022 19:40:00 +0200 Subject: [PATCH 41/54] also parse shape_freedom keyword --- ldm/invoke/prompt_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index e8d43b8afd..7cf638dfa5 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -470,7 +470,7 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_text_fragment) + pp.FollowedBy(".swap") ]) # support keyword=number arguments - cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end")]) + cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")]) cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number) edited_fragment = pp.MatchFirst([ lparen + From 8e7d744c6053617d7c0ab1a3ed30cfe48af24d02 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sun, 23 Oct 2022 19:43:35 +0200 Subject: [PATCH 42/54] fix bad math --- ldm/invoke/prompt_parser.py | 3 +-- ldm/models/diffusion/shared_invokeai_diffusion.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 7cf638dfa5..56b6bc7c42 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -1,4 +1,3 @@ -import math import string from typing import Union @@ -140,7 +139,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment): # and there is very little perceptible difference as s_end increases above 0.5 # so for shape_freedom = 0.5 we probably want s_end to be 0.2 # -> cube root and subtract from 1.0 - merged_options.s_end = 1.0 - math.cbrt(shape_freedom) + merged_options['s_end'] = 1.0 - shape_freedom ** (1. / 3.) 
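A quick sanity check of the corrected mapping above (illustrative only, not part of the patch): higher shape_freedom yields a smaller s_end, and shape_freedom=0.5 lands close to the 0.206 default adopted in the following commit.

# assumes the same cube-root mapping as the patched line above
for shape_freedom in (0.0, 0.125, 0.5, 1.0):
    s_end = 1.0 - shape_freedom ** (1. / 3.)
    print(f"shape_freedom={shape_freedom} -> s_end={s_end:.3f}")
# shape_freedom=0.0   -> s_end=1.000
# shape_freedom=0.125 -> s_end=0.500
# shape_freedom=0.5   -> s_end=0.206
# shape_freedom=1.0   -> s_end=0.000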
print('converted shape_freedom argument to', merged_options) merged_options.update(options) diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 85b594b8f7..c14e71be8d 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -1,4 +1,3 @@ -from enum import Enum from math import ceil from typing import Callable, Optional @@ -83,7 +82,7 @@ class InvokeAIDiffuserComponent: # flip because sigmas[0] is for the fully denoised image # percent_through must be <1 percent_through = 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0]) - print('estimated percent_through', percent_through, 'from sigma', sigma) + print('estimated percent_through', percent_through, 'from sigma', sigma.item()) cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through) if len(cross_attention_control_types_to_do)==0: From f7cd98c2386c770d3ef0422be2b258caebfc5573 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sun, 23 Oct 2022 20:38:28 +0200 Subject: [PATCH 43/54] tweak default cross-attention values --- ldm/invoke/prompt_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 56b6bc7c42..d5ebd18dfc 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -124,7 +124,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment): default_options = { 's_start': 0.0, - 's_end': 0.3, # gives best results + 's_end': 0.206, # ~= shape_freedom=0.5 't_start': 0.0, 't_end': 1.0 } From b0eb864a259149cca1667dc6ddc0c136062d251f Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sun, 23 Oct 2022 23:01:53 +0200 Subject: [PATCH 44/54] move attention weighting operations to postfix --- ldm/invoke/prompt_parser.py | 112 +++++++++++++++----------- tests/test_prompt_parser.py | 152 +++++++++++++++++++----------------- 2 files changed, 147 insertions(+), 117 deletions(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index d5ebd18dfc..48d08a7908 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -353,9 +353,34 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) else: raise PromptParser.ParsingException("Cannot make fragment from " + str(x)) + def build_escaped_word_parser(escaped_chars_to_ignore: str): + terms = [] + for c in escaped_chars_to_ignore: + terms.append(pp.Literal('\\'+c)) + terms.append( + #pp.CharsNotIn(string.whitespace + escaped_chars_to_ignore, exact=1) + pp.Word(pp.printables, exclude_chars=string.whitespace + escaped_chars_to_ignore) + ) + return pp.Combine(pp.OneOrMore( + pp.MatchFirst(terms) + )) + + def build_escaped_word_parser_charbychar(escaped_chars_to_ignore: str): + escapes = [] + for c in escaped_chars_to_ignore: + escapes.append(pp.Literal('\\'+c)) + return pp.Combine(pp.OneOrMore( + pp.MatchFirst(escapes + [pp.CharsNotIn( + string.whitespace + escaped_chars_to_ignore, + exact=1 + )]) + )) + + + def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False): + #print(f"parsing fragment string \"{x}\"") fragment_string = x[0] - #print(f"parsing fragment string \"{fragment_string}\"") if len(fragment_string.strip()) == 0: return Fragment('') @@ -401,59 +426,55 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) 
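
Ahead of the parser changes below, the postfix weighting syntax this patch moves to behaves roughly as sketched here (intended parse results only, not the parser itself; the dict is purely illustrative and assumes the default attention bases of 1.1 for '+' and 0.9 for '-'):

    # prompt string -> (fragment, weight) pairs
    postfix_examples = {
        "(flames)0.5":      [("flames", 0.5)],          # explicit numeric weight
        "(fire flames)0.5": [("fire flames", 0.5)],
        "flames+":          [("flames", 1.1)],           # each '+' multiplies by 1.1
        "(flames)++":       [("flames", 1.1 ** 2)],
        "flames--":         [("flames", 0.9 ** 2)],      # each '-' multiplies by 0.9
        "fire (flames)0.5": [("fire", 1.0), ("flames", 0.5)],
    }
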
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False) debug_attention = False - # attention control of the form +(phrase) / -(phrase) / (phrase) + # attention control of the form (phrase)+ / (phrase)+ / (phrase) # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight - attention_head = (number | pp.Word('+') | pp.Word('-'))\ - .set_name("attention_head")\ - .set_debug(False) - word_inside_attention = pp.Combine(pp.OneOrMore( - pp.Literal('\\)') | pp.Literal('\\(') | pp.Literal('\\"') | - pp.Word(pp.printables, exclude_chars=string.whitespace + '\\()"') - )).set_name('word_inside_attention') attention_with_parens = pp.Forward() + attention_without_parens = pp.Forward() - attention_with_parens_delimited_list = pp.OneOrMore(pp.Or([ - quoted_fragment.copy().set_debug(debug_attention), - attention.copy().set_debug(debug_attention), - cross_attention_substitute, - word_inside_attention.set_debug(debug_attention) - #pp.White() - ]).set_name('delim_inner').set_debug(debug_attention)) - # have to disable ignore_expr here to prevent pyparsing from stripping off quote marks - attention_with_parens_body = pp.nested_expr(content=attention_with_parens_delimited_list, - ignore_expr=None#((pp.Literal("\\(") | pp.Literal('\\)'))) - ) - attention_with_parens_body.set_debug(debug_attention) - attention_with_parens << (attention_head + attention_with_parens_body) + attention_with_parens_foot = (number | pp.Word('+') | pp.Word('-'))\ + .set_name("attention_foot")\ + .set_debug(False) + attention_with_parens <<= pp.Group( + lparen + + pp.ZeroOrMore(quoted_fragment | attention_with_parens | parenthesized_fragment | cross_attention_substitute | attention_without_parens | + (pp.Empty() + build_escaped_word_parser_charbychar('()')).set_name('undecorated_word').set_debug(debug_attention)#.set_parse_action(lambda t: t[0]) + ) + + rparen + attention_with_parens_foot) attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention) - attention_without_parens = (pp.Word('+') | pp.Word('-')) + (quoted_fragment | word_inside_attention) + attention_without_parens_foot = pp.Or(pp.Word('+') | pp.Word('-')).set_name('attention_without_parens_foots') + attention_without_parens <<= pp.Group( + (quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot) | + pp.Combine(build_escaped_word_parser_charbychar('()+-')).set_name('attention_word_without_parens').set_debug(debug_attention)#.set_parse_action(lambda x: print('escapéd', x)) + + attention_without_parens_foot)#.leave_whitespace() attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention) - attention << (attention_with_parens | attention_without_parens) + + attention << pp.MatchFirst([attention_with_parens, + attention_without_parens + ]) attention.set_name('attention') def make_attention(x): - #print("making Attention from", x) - weight = 1 - # number(str) - if type(x[0]) is float or type(x[0]) is int: - weight = float(x[0]) - # +(str) or -(str) or +str or -str - elif type(x[0]) is str: - base = attention_plus_base if x[0][0] == '+' else attention_minus_base - weight = pow(base, len(x[0])) - if type(x[1]) is list or type(x[1]) is pp.ParseResults: - return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in x[1]]) - elif type(x[1]) is str: - return Attention(weight=weight, children=[Fragment(x[1])]) - elif type(x[1]) is Fragment: - 
return Attention(weight=weight, children=[x[1]]) - raise PromptParser.ParsingException(f"Don't know how to make attention with children {x[1]}") + #print("entered make_attention with", x) + children = x[0][:-1] + weight_raw = x[0][-1] + weight = 1.0 + if type(weight_raw) is float or type(weight_raw) is int: + weight = weight_raw + elif type(weight_raw) is str: + base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base + weight = pow(base, len(weight_raw)) + + #print("making Attention from", children, "with weight", weight) + + return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in children]) attention_with_parens.set_parse_action(make_attention) attention_without_parens.set_parse_action(make_attention) + #print("parsing test:", attention_with_parens.parse_string("mountain (man)1.1")) + # cross-attention control empty_string = ((lparen + rparen) | pp.Literal('""').suppress() | @@ -487,10 +508,10 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control) def make_cross_attention_substitute(x): - print("making cacs for", x[0], "->", x[1], "with options", x.as_dict()) + #print("making cacs for", x[0], "->", x[1], "with options", x.as_dict()) #if len(x>2): cacs = CrossAttentionControlSubstitute(x[0], x[1], options=x.as_dict()) - print("made", cacs) + #print("made", cacs) return cacs cross_attention_substitute.set_parse_action(make_cross_attention_substitute) @@ -511,10 +532,11 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) (quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty') # root prompt definition - prompt = ((pp.OneOrMore(prompt_part | quoted_fragment) | empty) + pp.StringEnd()) \ + prompt = (pp.OneOrMore(pp.Or([prompt_part, quoted_fragment, empty])) + pp.StringEnd()) \ .set_parse_action(lambda x: Prompt(x)) - + #print("parsing test:", prompt.parse_string("spaced eyes--")) + #print("parsing test:", prompt.parse_string("eyes--")) # weighted blend of prompts # ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or @@ -536,7 +558,7 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string) quoted_prompt.set_name('quoted_prompt') - debug_blend=True + debug_blend=False blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend) blend_weights = (pp.delimited_list(number) + pp.Optional(pp.Char(",").suppress() + "no_normalize")).set_name('blend_weights').set_debug(debug_blend) blend = pp.Group(lparen + pp.Group(blend_terms) + rparen diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 02644012d8..203b95ddf0 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -34,27 +34,28 @@ class PromptParserTestCase(unittest.TestCase): self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire")) def test_attention(self): - self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("0.5(flames)")) - self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("0.5(fire flames)")) - self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("+(flames)")) - 
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("-(flames)")) - self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire 0.5(flames)")) - self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("++(flames)")) - self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("--(flames)")) - self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("---(flowers) +++flames")) - self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("---(flowers) +++flames")) - self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames+', pow(1.1, 3))]), - parse_prompt("---(flowers) +++flames+")) + self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5")) + self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5")) + self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+")) + self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+")) + self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+")) + self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-")) + self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-")) + self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("\"flames\"-")) + self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5")) + self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++")) + self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--")) + self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++")) self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1)]), - parse_prompt("+(pretty flowers)")) + parse_prompt("(pretty flowers)+")) self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1), (', the flames are too hot', 1)]), - parse_prompt("+(pretty flowers), the flames are too hot")) + parse_prompt("(pretty flowers)+, the flames are too hot")) def test_no_parens_attention_runon(self): - self.assertEqual(make_weighted_conjunction([('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("++fire flames")) - self.assertEqual(make_weighted_conjunction([('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("--fire flames")) - self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("flowers ++fire flames")) - self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("flowers --fire flames")) + self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(1.1, 2))]), parse_prompt("fire flames++")) + self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(0.9, 2))]), parse_prompt("fire flames--")) + self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("flowers fire++ flames")) + self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("flowers fire-- flames")) def test_explicit_conjunction(self): @@ -62,7 +63,7 @@ 
class PromptParserTestCase(unittest.TestCase): self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()')) self.assertEqual( Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()')) - self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("2.0(fire)", "-flames").and()')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("(fire)2.0", "flames-").and()')) self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()')) @@ -75,8 +76,11 @@ class PromptParserTestCase(unittest.TestCase): parse_prompt('("fire", "flames").and(2,1,2)') def test_complex_conjunction(self): + + #print(parse_prompt("a person with a hat (riding a bicycle.swap(skateboard))++")) + self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]), - parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle)\").and(0.5, 0.5)")) + parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle)++\").and(0.5, 0.5)")) self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a", 1.1*1.1), @@ -85,7 +89,7 @@ class PromptParserTestCase(unittest.TestCase): [Fragment("skateboard", pow(1.1,2))]) ]) ], weights=[0.5, 0.5]), - parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle.swap(skateboard))\").and(0.5, 0.5)")) + parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)")) def test_badly_formed(self): def make_untouched_prompt(prompt): @@ -95,24 +99,25 @@ class PromptParserTestCase(unittest.TestCase): self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt)) assert_if_prompt_string_not_untouched('a test prompt') - assert_if_prompt_string_not_untouched('a badly formed test+ prompt') + # todo handle this + #assert_if_prompt_string_not_untouched('a badly formed +test prompt') with self.assertRaises(pyparsing.ParseException): parse_prompt('a badly (formed test prompt') #with self.assertRaises(pyparsing.ParseException): with self.assertRaises(pyparsing.ParseException): - parse_prompt('a badly (formed test+ prompt') + parse_prompt('a badly (formed +test prompt') with self.assertRaises(pyparsing.ParseException): - parse_prompt('a badly (formed test+ )prompt') + parse_prompt('a badly (formed +test )prompt') with self.assertRaises(pyparsing.ParseException): - parse_prompt('a badly (formed test+ )prompt') + parse_prompt('a badly (formed +test )prompt') with self.assertRaises(pyparsing.ParseException): - parse_prompt('(((a badly (formed test+ )prompt') + parse_prompt('(((a badly (formed +test )prompt') with self.assertRaises(pyparsing.ParseException): - parse_prompt('(a (ba)dly (f)ormed test+ prompt') + parse_prompt('(a (ba)dly (f)ormed +test prompt') with self.assertRaises(pyparsing.ParseException): - parse_prompt('(a (ba)dly (f)ormed test+ +prompt') + parse_prompt('(a (ba)dly (f)ormed +test +prompt') with self.assertRaises(pyparsing.ParseException): - parse_prompt('("((a badly (formed test+ ").blend(1.0)') + 
parse_prompt('("((a badly (formed +test ").blend(1.0)') def test_blend(self): @@ -129,7 +134,7 @@ class PromptParserTestCase(unittest.TestCase): FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]), FlattenedPrompt([('hi', 1.0)])], weights=[0.7, 0.3, 1.0])]), - parse_prompt("(\"fire\", \"fire flames ++(hot)\", \"hi\").blend(0.7, 0.3, 1.0)") + parse_prompt("(\"fire\", \"fire flames (hot)++\", \"hi\").blend(0.7, 0.3, 1.0)") ) # blend a single entry is not a failure self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]), @@ -156,17 +161,17 @@ class PromptParserTestCase(unittest.TestCase): self.assertEqual( Conjunction([Blend([FlattenedPrompt([('mountain, man, hairy', 1)]), FlattenedPrompt([('face, teeth,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0])]), - parse_prompt('("mountain, man, hairy", "face, teeth, --eyes").blend(1,-1)') + parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)') ) def test_nested(self): self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)]), - parse_prompt('fire 2.0(flames 1.5(trees))')) + parse_prompt('fire (flames (trees)1.5)2.0')) self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]), FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])], weights=[1.0, 1.0])]), - parse_prompt('("fire ++(flames)", "mountain 2(man)").blend(1,1)')) + parse_prompt('("fire (flames)++", "mountain (man)2").blend(1,1)')) def test_cross_attention_control(self): @@ -237,15 +242,15 @@ class PromptParserTestCase(unittest.TestCase): flames_to_trees_fire = Conjunction([FlattenedPrompt([ CrossAttentionControlSubstitute([Fragment('flames',0.5)], [Fragment('trees',0.7)]), Fragment(',', 1), Fragment('fire', 2.0)])]) - self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(flames)".swap("0.7(trees)"), 2.0(fire)')) + self.assertEqual(flames_to_trees_fire, parse_prompt('"(flames)0.5".swap("(trees)0.7"), (fire)2.0')) flames_to_trees_fire = Conjunction([FlattenedPrompt([ CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7)]), Fragment(',', 1), Fragment('fire', 2.0)])]) - self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees)"), 2.0(fire)')) + self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7"), (fire)2.0')) flames_to_trees_fire = Conjunction([FlattenedPrompt([ CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7), Fragment('houses', 1)]), Fragment(',', 1), Fragment('fire', 2.0)])]) - self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)')) + self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0')) def test_cross_attention_control_options(self): self.assertEqual(Conjunction([ @@ -271,48 +276,48 @@ class PromptParserTestCase(unittest.TestCase): self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain \(man\)')) self.assertEqual(make_basic_conjunction(['mountain (man )']),parse_prompt('mountain (\(man)\)')) self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain (\(man\))')) - self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('mountain +(\(man\))')) - self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('"mountain" +(\(man\))')) 
- self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('(man)', 1.1)]), parse_prompt('\\"mountain\\" +(\(man\))')) + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('mountain (\(man\))+')) + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('"mountain" (\(man\))+')) + self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('(man)', 1.1)]), parse_prompt('\\"mountain\\" (\(man\))+')) # same weights for each are combined into one - self.assertEqual(make_weighted_conjunction([('"mountain" (man)', 1.1)]), parse_prompt('+(\\"mountain\\") +(\(man\))')) - self.assertEqual(make_weighted_conjunction([('"mountain"', 1.1), ('(man)', 0.9)]), parse_prompt('+(\\"mountain\\") -(\(man\))')) + self.assertEqual(make_weighted_conjunction([('"mountain" (man)', 1.1)]), parse_prompt('(\\"mountain\\")+ (\(man\))+')) + self.assertEqual(make_weighted_conjunction([('"mountain"', 1.1), ('(man)', 0.9)]), parse_prompt('(\\"mountain\\")+ (\(man\))-')) - self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain 1.1(\(man\))')) - self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('"mountain" 1.1(\(man\))')) - self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('\(man\)', 1.1)]),parse_prompt('\\"mountain\\" 1.1(\(man\))')) + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain (\(man\))1.1')) + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('"mountain" (\(man\))1.1')) + self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('\(man\)', 1.1)]),parse_prompt('\\"mountain\\" (\(man\))1.1')) # same weights for each are combined into one - self.assertEqual(make_weighted_conjunction([('\\"mountain\\" \(man\)', 1.1)]),parse_prompt('+(\\"mountain\\") 1.1(\(man\))')) - self.assertEqual(make_weighted_conjunction([('\\"mountain\\"', 1.1), ('\(man\)', 0.9)]),parse_prompt('1.1(\\"mountain\\") 0.9(\(man\))')) + self.assertEqual(make_weighted_conjunction([('\\"mountain\\" \(man\)', 1.1)]),parse_prompt('(\\"mountain\\")+ (\(man\))1.1')) + self.assertEqual(make_weighted_conjunction([('\\"mountain\\"', 1.1), ('\(man\)', 0.9)]),parse_prompt('(\\"mountain\\")1.1 (\(man\))0.9')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy +(mountain +(\(man\)))')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy +(1.1(\(man\)) "mountain")')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy +("mountain" 1.1(\(man\)) )')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man', 1.1)]),parse_prompt('hairy +("mountain, man")')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy +("mountain, man" with a +beard)')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, man" with a 2.0(beard))')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\"man\\"" with a 2.0(beard))')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, m\"an\" 
with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, m\\"an\\"" with a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy (mountain (\(man\))+)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy ((\(man\))1.1 "mountain")+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy ("mountain" (\(man\))1.1 )+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man', 1.1)]),parse_prompt('hairy ("mountain, man")+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy ("mountain, man" with a beard+)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, man" with a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\"man\\"" with a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, m\\"an\\"" with a (beard)2.0)+')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" \(with a 2.0(beard))')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" w\(ith a 2.0(beard))')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" with\( a 2.0(beard))')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" \)with a 2.0(beard))')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" w\)ith a 2.0(beard))')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" with\) a 2.0(beard))')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mou\)ntain, \\\"man\" \(wit\(h a 2.0(beard))')) - self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry +("mountain, \\\"man\" w\)ith a 2.0(beard))')) - self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain, \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( +("mountain, \\\"man\" with a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \(with a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\(ith a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy 
("mountain, \\\"man\" with\( a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \)with a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\)ith a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\) a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry ("mountain, \\\"man\" w\)ith a (beard)2.0)+')) + self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain, \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( ("mountain, \\\"man\" with a (beard)2.0)+')) - self.assertEqual(make_weighted_conjunction([('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" \(with a 2.0(beard)) hairy')) - self.assertEqual(make_weighted_conjunction([('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" w\(ith a 2.0(beard))hairy')) - self.assertEqual(make_weighted_conjunction([('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" with\( a 2.0(beard)) hairy')) - self.assertEqual(make_weighted_conjunction([('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" \)with a 2.0(beard)) hairy')) - self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" w\)ith a 2.0(beard)) hairy')) - self.assertEqual(make_weighted_conjunction([('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' +("mountain, \\\"man\" with\) a 2.0(beard)) hairy')) - self.assertEqual(make_weighted_conjunction([('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mou\)ntain, \\\"man\" \(wit\(h a 2.0(beard)) hairy')) - self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('+("mountain, \\\"man\" w\)ith a 2.0(beard)) hai\(ry ')) - self.assertEqual(make_weighted_conjunction([('mountain, \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('+("mountain, \\\"man\" with a 2.0(beard)) hairy\(\( ')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \(with a (beard)2.0)+ hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\(ith a (beard)2.0)+hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" with\( a (beard)2.0)+ hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \)with a (beard)2.0)+ hairy')) + 
self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' ("mountain, \\\"man\" with\) a (beard)2.0)+ hairy')) + self.assertEqual(make_weighted_conjunction([('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+ hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hai\(ry ')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('("mountain, \\\"man\" with a (beard)2.0)+ hairy\(\( ')) def test_cross_attention_escaping(self): @@ -339,7 +344,10 @@ class PromptParserTestCase(unittest.TestCase): parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)')) def test_single(self): - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mou\)ntain, \\\"man\" \(wit\(h a 2.0(beard))')) + # todo handle this + #self.assertEqual(make_basic_conjunction(['a badly formed +test prompt']), + # parse_prompt('a badly formed +test prompt')) + pass if __name__ == '__main__': From 92c6a3812dcc8be178ceac4d9580aad9937094d0 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Mon, 24 Oct 2022 00:06:53 +0200 Subject: [PATCH 45/54] catch fewer exceptions in prompt2image --- ldm/generate.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ldm/generate.py b/ldm/generate.py index 39f0b06759..f6d5a12ebf 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -1,5 +1,5 @@ # Copyright (c) 2022 Lincoln D. 
Stein (https://github.com/lstein) - +import pyparsing # Derived from source code carrying the following copyrights # Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors @@ -24,6 +24,7 @@ from PIL import Image, ImageOps from torch import nn from pytorch_lightning import seed_everything, logging +from ldm.invoke.prompt_parser import PromptParser from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler @@ -474,7 +475,7 @@ class Generate: print('**Interrupted** Partial results will be returned.') else: raise KeyboardInterrupt - except (RuntimeError, Exception) as e: + except RuntimeError as e: print(traceback.format_exc(), file=sys.stderr) print('>> Could not generate image.') From 2619a0b28641c3370095df9606caf332f1fbf16a Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Mon, 24 Oct 2022 00:22:14 +0200 Subject: [PATCH 46/54] allow longer substitutions without quotes for cross attention swap --- ldm/invoke/prompt_parser.py | 2 +- tests/test_prompt_parser.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 48d08a7908..830c5313e3 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -495,7 +495,7 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) edited_fragment = pp.MatchFirst([ lparen + (quoted_fragment | - pp.Group(pp.Word(pp.printables, exclude_chars=string.whitespace + ',').set_parse_action(make_text_fragment)) + pp.Group(pp.OneOrMore(pp.Word(pp.printables, exclude_chars=string.whitespace + ',').set_parse_action(make_text_fragment))) ) + pp.Dict(pp.OneOrMore(comma + cross_attention_option)) + rparen, diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 203b95ddf0..0c4d9106db 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -269,6 +269,14 @@ class PromptParserTestCase(unittest.TestCase): Fragment('eating a hotdog', 1)])]), parse_prompt("a \"cat\".swap(dog, t_start=0.1, s_start=20) eating a hotdog")) + self.assertEqual( + Conjunction([ + FlattenedPrompt([Fragment('a fantasy forest landscape', 1), + CrossAttentionControlSubstitute([Fragment('', 1)], [Fragment('with a river', 1)], + options={'s_start': 0.8, 't_start': 0.8})])]), + parse_prompt("a fantasy forest landscape \"\".swap(with a river, s_start=0.8, t_start=0.8)")) + + def test_escaping(self): # make sure ", ( and ) can be escaped From ee4273d760e642cc33f15f85e0b82eeaaa8cc375 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Mon, 24 Oct 2022 01:23:43 +0200 Subject: [PATCH 47/54] fix step count on ddim --- ldm/models/diffusion/sampler.py | 1 + ldm/models/diffusion/shared_invokeai_diffusion.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index 8099997bb3..79c15717fe 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -359,6 +359,7 @@ class Sampler(object): unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, t_next = ts_next, + step_count=total_steps ) x_dec, pred_x0, e_t = outs diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index c14e71be8d..4bf5688586 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ 
b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -86,14 +86,14 @@ class InvokeAIDiffuserComponent: cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through) if len(cross_attention_control_types_to_do)==0: - #print('step', step_index, ': not doing cross attention control') + print('pct', percent_through, ': not doing cross attention control') # faster batched path x_twice = torch.cat([x]*2) sigma_twice = torch.cat([sigma]*2) both_conditionings = torch.cat([unconditioning, conditioning]) unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2) else: - #print('step', step_index, ': doing cross attention control on', cross_attention_control_types_to_do) + print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do) # slower non-batched path (20% slower on mac MPS) # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x. From cc2042bd4ca8a237e8e473563daa6e65fa431215 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Mon, 24 Oct 2022 01:43:35 +0200 Subject: [PATCH 48/54] keep the effect of _start and _end arguments consistent across k* and other samplers --- ldm/invoke/generator/img2img.py | 3 ++- ldm/models/diffusion/cross_attention_control.py | 4 ++++ ldm/models/diffusion/ddim.py | 3 ++- ldm/models/diffusion/plms.py | 3 ++- ldm/models/diffusion/sampler.py | 7 ++++--- 5 files changed, 14 insertions(+), 6 deletions(-) diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index 4942bcc0c3..e12f5e5e79 100644 --- a/ldm/invoke/generator/img2img.py +++ b/ldm/invoke/generator/img2img.py @@ -51,7 +51,8 @@ class Img2Img(Generator): unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, init_latent = self.init_latent, # changes how noising is performed in ksampler - extra_conditioning_info = extra_conditioning_info + extra_conditioning_info = extra_conditioning_info, + all_timesteps_count = steps ) return self.sample_to_image(samples) diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 6e873f1c6d..1e5b073a3d 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -29,6 +29,10 @@ class CrossAttentionControl: class Context: def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int): + """ + :param arguments: Arguments for the cross-attention control process + :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run) + """ self.arguments = arguments self.step_count = step_count diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index 5b5dfaf4af..b11e8578e7 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -16,9 +16,10 @@ class DDIMSampler(Sampler): super().prepare_to_sample(t_enc, **kwargs) extra_conditioning_info = kwargs.get('extra_conditioning_info', None) + all_timesteps_count = kwargs.get('all_timesteps_count', t_enc) if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = 
t_enc) + self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count) else: self.invokeai_diffuser.remove_cross_attention_control() diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 40c2631bcd..6bd519b63b 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -21,9 +21,10 @@ class PLMSSampler(Sampler): super().prepare_to_sample(t_enc, **kwargs) extra_conditioning_info = kwargs.get('extra_conditioning_info', None) + all_timesteps_count = kwargs.get('all_timesteps_count', t_enc) if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: - self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc) + self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count) else: self.invokeai_diffuser.remove_cross_attention_control() diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index 79c15717fe..853702ef68 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -235,7 +235,7 @@ class Sampler(object): dynamic_ncols=True, ) old_eps = [] - self.prepare_to_sample(t_enc=total_steps,**kwargs) + self.prepare_to_sample(t_enc=total_steps,all_timesteps_count=steps,**kwargs) img = self.get_initial_image(x_T,shape,total_steps) # probably don't need this at all @@ -310,6 +310,7 @@ class Sampler(object): use_original_steps=False, init_latent = None, mask = None, + all_timesteps_count = None, **kwargs ): @@ -327,7 +328,7 @@ class Sampler(object): iterator = tqdm(time_range, desc='Decoding image', total=total_steps) x_dec = x_latent x0 = init_latent - self.prepare_to_sample(t_enc=total_steps,**kwargs) + self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs) for i, step in enumerate(iterator): index = total_steps - i - 1 @@ -359,7 +360,7 @@ class Sampler(object): unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, t_next = ts_next, - step_count=total_steps + step_count=len(self.ddim_timesteps) ) x_dec, pred_x0, e_t = outs From 1fb15d5c8126e46c9ecfa2538393327ce798d858 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Mon, 24 Oct 2022 02:02:42 +0200 Subject: [PATCH 49/54] fix hires fix --- ldm/invoke/generator/txt2img2img.py | 5 +++-- ldm/models/diffusion/shared_invokeai_diffusion.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index 5808f7bdb2..c8438d1f7f 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -23,7 +23,7 @@ class Txt2Img2Img(Generator): Return value depends on the seed at the time you call it kwargs are 'width' and 'height' """ - uc, c, extra_conditioing_info = conditioning + uc, c, extra_conditioning_info = conditioning @torch.no_grad() def make_image(x_T): @@ -96,7 +96,8 @@ class Txt2Img2Img(Generator): img_callback = step_callback, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, - extra_conditioning_info = extra_conditioning_info + # cross-attention control is disabled during upscale + #extra_conditioning_info = None ) if self.free_gpu_mem: diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 4bf5688586..f52dd46766 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ 
b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -86,7 +86,7 @@ class InvokeAIDiffuserComponent: cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through) if len(cross_attention_control_types_to_do)==0: - print('pct', percent_through, ': not doing cross attention control') + print('not doing cross attention control') # faster batched path x_twice = torch.cat([x]*2) sigma_twice = torch.cat([sigma]*2) From 63902f3d3412376a779acbef57acf1826ef7c0fd Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Mon, 24 Oct 2022 02:08:55 +0200 Subject: [PATCH 50/54] also apply conditioing during hires fix upscale --- ldm/invoke/generator/txt2img2img.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index c8438d1f7f..a10aad58ac 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -96,8 +96,8 @@ class Txt2Img2Img(Generator): img_callback = step_callback, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, - # cross-attention control is disabled during upscale - #extra_conditioning_info = None + extra_conditioning_info=extra_conditioning_info, + all_timesteps_count=steps ) if self.free_gpu_mem: From 0564397ee6d206a7ede5ff16628b4b46578339c4 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Mon, 24 Oct 2022 11:16:43 +0200 Subject: [PATCH 51/54] cleanup logs --- ldm/models/diffusion/shared_invokeai_diffusion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index f52dd46766..b8a7a04d0e 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -82,18 +82,18 @@ class InvokeAIDiffuserComponent: # flip because sigmas[0] is for the fully denoised image # percent_through must be <1 percent_through = 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0]) - print('estimated percent_through', percent_through, 'from sigma', sigma.item()) + #print('estimated percent_through', percent_through, 'from sigma', sigma.item()) cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through) if len(cross_attention_control_types_to_do)==0: - print('not doing cross attention control') + #print('not doing cross attention control') # faster batched path x_twice = torch.cat([x]*2) sigma_twice = torch.cat([sigma]*2) both_conditionings = torch.cat([unconditioning, conditioning]) unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2) else: - print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do) + #print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do) # slower non-batched path (20% slower on mac MPS) # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x. 
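
The sampler-side gating shown in the hunk above can be summarized by the following sketch (the standalone function and its name are an assumption for illustration; the k_dpm_2 caveat from the earlier todo note still applies because its sigma sequence jumps around):

    import torch

    def estimate_percent_through(sigma: torch.Tensor, sigmas: torch.Tensor) -> float:
        # find where the current sigma sits in the model's sigma schedule
        sigma_index = torch.nonzero(sigmas <= sigma)[-1]
        # flip because sigmas[0] is for the fully denoised image; result stays < 1
        return 1.0 - float(sigma_index.item() + 1) / float(sigmas.shape[0])

    # get_active_cross_attention_control_types_for_step() compares this value
    # against the s_start/s_end and t_start/t_end windows; PATCH 48 threads
    # all_timesteps_count through so the windows refer to the absolute diffusion
    # schedule rather than the possibly shorter img2img step count.
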
From 44e4090909c0375f9344f647e85a2107f6b1ded0 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Mon, 24 Oct 2022 11:16:52 +0200 Subject: [PATCH 52/54] re-enable legacy blend syntax --- backend/invoke_ai_web_server.py | 2 +- backend/server.py | 2 +- ldm/invoke/args.py | 2 +- ldm/invoke/conditioning.py | 70 ++++---------------------- ldm/invoke/prompt_parser.py | 88 ++++++++++++++++++++++++++++++--- tests/test_prompt_parser.py | 41 ++++++++++++++- 6 files changed, 134 insertions(+), 71 deletions(-) diff --git a/backend/invoke_ai_web_server.py b/backend/invoke_ai_web_server.py index 96ecda1af1..dabe072f80 100644 --- a/backend/invoke_ai_web_server.py +++ b/backend/invoke_ai_web_server.py @@ -14,7 +14,7 @@ from threading import Event from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash from ldm.invoke.pngwriter import PngWriter, retrieve_metadata -from ldm.invoke.conditioning import split_weighted_subprompts +from ldm.invoke.prompt_parser import split_weighted_subprompts from backend.modules.parameters import parameters_to_command diff --git a/backend/server.py b/backend/server.py index 7b8a8a5a69..8ad861356c 100644 --- a/backend/server.py +++ b/backend/server.py @@ -33,7 +33,7 @@ from ldm.generate import Generate from ldm.invoke.restoration import Restoration from ldm.invoke.pngwriter import PngWriter, retrieve_metadata from ldm.invoke.args import APP_ID, APP_VERSION, calculate_init_img_hash -from ldm.invoke.conditioning import split_weighted_subprompts +from ldm.invoke.prompt_parser import split_weighted_subprompts from modules.parameters import parameters_to_command diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py index 26920f28ea..12e9f96f6b 100644 --- a/ldm/invoke/args.py +++ b/ldm/invoke/args.py @@ -92,7 +92,7 @@ import copy import base64 import functools import ldm.invoke.pngwriter -from ldm.invoke.conditioning import split_weighted_subprompts +from ldm.invoke.prompt_parser import split_weighted_subprompts SAMPLER_CHOICES = [ 'ddim', diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index 52d40312ac..65459b5c5f 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -41,9 +41,15 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n pp = PromptParser() - # we don't support conjunctions for now - parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned).prompts[0] - parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words).prompts[0] + parsed_prompt: Union[FlattenedPrompt, Blend] = None + legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned) + if legacy_blend is not None: + parsed_prompt = legacy_blend + else: + # we don't support conjunctions for now + parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0] + + parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0] print("parsed prompt to", parsed_prompt) conditioning = None @@ -146,61 +152,3 @@ def get_tokens_length(model, fragments: list[Fragment]): return sum([len(x) for x in tokens]) -def split_weighted_subprompts(text, skip_normalize=False)->list: - """ - grabs all text up to the first occurrence of ':' - uses the grabbed text as a sub-prompt, and takes the value following ':' as weight - if ':' has no value defined, defaults to 1.0 - repeats until no text remaining - """ - prompt_parser = re.compile(""" - (?P # capture group for 'prompt' - (?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:' - ) # end 'prompt' 
- (?: # non-capture group - :+ # match one or more ':' characters - (?P # capture group for 'weight' - -?\d+(?:\.\d+)? # match positive or negative integer or decimal number - )? # end weight capture group, make optional - \s* # strip spaces after weight - | # OR - $ # else, if no ':' then match end of line - ) # end non-capture group - """, re.VERBOSE) - parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float( - match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)] - if skip_normalize: - return parsed_prompts - weight_sum = sum(map(lambda x: x[1], parsed_prompts)) - if weight_sum == 0: - print( - "Warning: Subprompt weights add up to zero. Discarding and using even weights instead.") - equal_weight = 1 / max(len(parsed_prompts), 1) - return [(x[0], equal_weight) for x in parsed_prompts] - return [(x[0], x[1] / weight_sum) for x in parsed_prompts] - -# shows how the prompt is tokenized -# usually tokens have '' to indicate end-of-word, -# but for readability it has been replaced with ' ' -def log_tokenization(text, model, log=False, weight=1): - if not log: - return - tokens = model.cond_stage_model.tokenizer._tokenize(text) - tokenized = "" - discarded = "" - usedTokens = 0 - totalTokens = len(tokens) - for i in range(0, totalTokens): - token = tokens[i].replace('', ' ') - # alternate color - s = (usedTokens % 6) + 1 - if i < model.cond_stage_model.max_length: - tokenized = tokenized + f"\x1b[0;3{s};40m{token}" - usedTokens += 1 - else: # over max token length - discarded = discarded + f"\x1b[0;3{s};40m{token}" - print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m") - if discarded != "": - print( - f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m" - ) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 830c5313e3..3a96d664f0 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -1,6 +1,6 @@ import string -from typing import Union - +from typing import Union, Optional +import re import pyparsing as pp class Prompt(): @@ -223,10 +223,10 @@ class PromptParser(): def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9): - self.root = build_parser_syntax(attention_plus_base, attention_minus_base) + self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base) - def parse(self, prompt: str) -> Conjunction: + def parse_conjunction(self, prompt: str) -> Conjunction: ''' :param prompt: The prompt string to parse :return: a Conjunction representing the parsed results. 
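
Usage of the re-enabled legacy syntax, as exercised by the tests later in this patch, looks roughly like this (a sketch only; the import path is the one shown in the diff, and the commented results mirror the test expectations):

    from ldm.invoke.prompt_parser import PromptParser

    pp = PromptParser()

    # colon-weighted subprompts are detected and turned into a normalized Blend
    blend = pp.parse_legacy_blend('mountain man:3 man mountain:1')
    # -> Blend of two FlattenedPrompts with weights [0.75, 0.25]

    # a prompt with a single weighted subprompt is not a legacy blend, so the
    # caller falls back to the new parser
    assert pp.parse_legacy_blend('mountain man') is None
    conjunction = pp.parse_conjunction('mountain man')
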
@@ -236,13 +236,25 @@ class PromptParser(): if len(prompt.strip()) == 0: return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0]) - root = self.root.parse_string(prompt) + root = self.conjunction.parse_string(prompt) #print(f"'{prompt}' parsed to root", root) #fused = fuse_fragments(parts) #print("fused to", fused) return self.flatten(root[0]) + def parse_legacy_blend(self, text: str) -> Optional[Blend]: + weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False) + if len(weighted_subprompts) == 1: + return None + strings = [x[0] for x in weighted_subprompts] + weights = [x[1] for x in weighted_subprompts] + + parsed_conjunctions = [self.parse_conjunction(x) for x in strings] + flattened_prompts = [x.prompts[0] for x in parsed_conjunctions] + + return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True) + def flatten(self, root: Conjunction) -> Conjunction: """ @@ -596,4 +608,68 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) conjunction.set_debug(False) # top-level is a conjunction of one or more blends or prompts - return conjunction + return conjunction, prompt + + + +def split_weighted_subprompts(text, skip_normalize=False)->list: + """ + Legacy blend parsing. + + grabs all text up to the first occurrence of ':' + uses the grabbed text as a sub-prompt, and takes the value following ':' as weight + if ':' has no value defined, defaults to 1.0 + repeats until no text remaining + """ + prompt_parser = re.compile(""" + (?P # capture group for 'prompt' + (?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:' + ) # end 'prompt' + (?: # non-capture group + :+ # match one or more ':' characters + (?P # capture group for 'weight' + -?\d+(?:\.\d+)? # match positive or negative integer or decimal number + )? # end weight capture group, make optional + \s* # strip spaces after weight + | # OR + $ # else, if no ':' then match end of line + ) # end non-capture group + """, re.VERBOSE) + parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float( + match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)] + if skip_normalize: + return parsed_prompts + weight_sum = sum(map(lambda x: x[1], parsed_prompts)) + if weight_sum == 0: + print( + "Warning: Subprompt weights add up to zero. 
Discarding and using even weights instead.") + equal_weight = 1 / max(len(parsed_prompts), 1) + return [(x[0], equal_weight) for x in parsed_prompts] + return [(x[0], x[1] / weight_sum) for x in parsed_prompts] + + +# shows how the prompt is tokenized +# usually tokens have '' to indicate end-of-word, +# but for readability it has been replaced with ' ' +def log_tokenization(text, model, log=False, weight=1): + if not log: + return + tokens = model.cond_stage_model.tokenizer._tokenize(text) + tokenized = "" + discarded = "" + usedTokens = 0 + totalTokens = len(tokens) + for i in range(0, totalTokens): + token = tokens[i].replace('', 'x` ') + # alternate color + s = (usedTokens % 6) + 1 + if i < model.cond_stage_model.max_length: + tokenized = tokenized + f"\x1b[0;3{s};40m{token}" + usedTokens += 1 + else: # over max token length + discarded = discarded + f"\x1b[0;3{s};40m{token}" + print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m") + if discarded != "": + print( + f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m" + ) diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 0c4d9106db..486265d2f5 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -9,7 +9,7 @@ from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, Flattened def parse_prompt(prompt_string): pp = PromptParser() #print(f"parsing '{prompt_string}'") - parse_result = pp.parse(prompt_string) + parse_result = pp.parse_conjunction(prompt_string) #print(f"-> parsed '{prompt_string}' to {parse_result}") return parse_result @@ -351,6 +351,45 @@ class PromptParserTestCase(unittest.TestCase): self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]), parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)')) + def test_legacy_blend(self): + pp = PromptParser() + + self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]), + FlattenedPrompt([('man mountain', 1)])], + weights=[0.5,0.5]), + pp.parse_legacy_blend('mountain man:1 man mountain:1')) + + self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]), + FlattenedPrompt([('man', 1), ('mountain', 0.9)])], + weights=[0.5,0.5]), + pp.parse_legacy_blend('mountain+ man:1 man mountain-:1')) + + self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]), + FlattenedPrompt([('man', 1), ('mountain', 0.9)])], + weights=[0.5,0.5]), + pp.parse_legacy_blend('mountain+ man:1 man mountain-')) + + self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]), + FlattenedPrompt([('man', 1), ('mountain', 0.9)])], + weights=[0.5,0.5]), + pp.parse_legacy_blend('mountain+ man: man mountain-:')) + + self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]), + FlattenedPrompt([('man mountain', 1)])], + weights=[0.75,0.25]), + pp.parse_legacy_blend('mountain man:3 man mountain:1')) + + self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]), + FlattenedPrompt([('man mountain', 1)])], + weights=[1.0,0.0]), + pp.parse_legacy_blend('mountain man:3 man mountain:0')) + + self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]), + FlattenedPrompt([('man mountain', 1)])], + weights=[0.8,0.2]), + pp.parse_legacy_blend('"mountain man":4 man mountain')) + + def test_single(self): # todo handle this #self.assertEqual(make_basic_conjunction(['a badly formed +test prompt']), From 61a4897b71432e5a6d3307fb112f6023e4278409 Mon Sep 17 00:00:00 2001 From: Damian at 
From 61a4897b71432e5a6d3307fb112f6023e4278409 Mon Sep 17 00:00:00 2001
From: Damian at mba
Date: Mon, 24 Oct 2022 11:49:47 +0200
Subject: [PATCH 53/54] re-enable tokenization logging

---
 ldm/invoke/conditioning.py  | 29 ++++++++++++++++++++++-------
 ldm/invoke/prompt_parser.py |  5 +++++
 2 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py
index 65459b5c5f..7c095de7b7 100644
--- a/ldm/invoke/conditioning.py
+++ b/ldm/invoke/conditioning.py
@@ -50,7 +50,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
     parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
     parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
-    print("parsed prompt to", parsed_prompt)
+    print(f">> Parsed prompt to {parsed_prompt}")
 
     conditioning = None
     cac_args:CrossAttentionControl.Arguments = None
@@ -59,7 +59,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
         blend: Blend = parsed_prompt
         embeddings_to_blend = None
         for flattened_prompt in blend.prompts:
-            this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
+            this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
             embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
                 (embeddings_to_blend, this_embedding))
         conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
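
For Blend prompts, the hunk above stacks one embedding per sub-prompt with `torch.cat` and then hands the blend weights to `WeightedFrozenCLIPEmbedder.apply_embedding_weights`. The snippet below is only a rough stand-in for that weighted combination, not the real method; the tensor shapes and the normalization step are assumptions made purely for illustration.

```python
import torch

def blend_embeddings(embeddings: torch.Tensor, weights, normalize=True) -> torch.Tensor:
    # Toy stand-in: weighted sum of per-prompt embeddings along dim 0.
    # `embeddings` is assumed to be (num_prompts, num_tokens, channels),
    # i.e. one CLIP embedding per blend sub-prompt, stacked together.
    w = torch.tensor(weights, dtype=embeddings.dtype)
    if normalize:
        w = w / w.sum()
    # broadcast each weight over its prompt's tokens/channels, then sum the prompts
    return (embeddings * w.view(-1, 1, 1)).sum(dim=0)

# e.g. two 77x768 embeddings blended 3:1
blended = blend_embeddings(torch.randn(2, 77, 768), [3.0, 1.0])
```
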
@@ -103,14 +103,14 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
                 edit_options.append(None)
                 original_token_count += count
                 edited_token_count += count
-        original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt)
+        original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt, log_tokens=log_tokens)
         # naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
         # subsequent tokens when there is >1 edit and earlier edits change the total token count.
         # eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
         # 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
         # token 'smiling' in the inactive 'cat' edit.
         # todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
-        edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt)
+        edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt, log_tokens=log_tokens)
 
         conditioning = original_embeddings
         edited_conditioning = edited_embeddings
@@ -121,10 +121,10 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
                                 edit_options = edit_options
                             )
     else:
-        conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
+        conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
 
 
-    unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt)
+    unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
     return (
         unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
             cross_attention_control_args=cac_args
@@ -138,12 +138,27 @@ def build_token_edit_opcodes(original_tokens, edited_tokens):
     return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
 
-def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
+def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False):
     if type(flattened_prompt) is not FlattenedPrompt:
         raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
     fragments = [x.text for x in flattened_prompt.children]
     weights = [x.weight for x in flattened_prompt.children]
     embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
+    if not flattened_prompt.is_empty and log_tokens:
+        start_token = model.cond_stage_model.tokenizer.bos_token_id
+        end_token = model.cond_stage_model.tokenizer.eos_token_id
+        tokens_list = tokens[0].tolist()
+        if tokens_list[0] == start_token:
+            tokens_list[0] = '<start>'
+        try:
+            first_end_token_index = tokens_list.index(end_token)
+            tokens_list[first_end_token_index] = '<end>'
+            tokens_list = tokens_list[:first_end_token_index+1]
+        except ValueError:
+            pass
+
+        print(f">> Prompt fragments {fragments}, tokenized to \n{tokens_list}")
+
     return embeddings, tokens
 
 def get_tokens_length(model, fragments: list[Fragment]):
diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py
index 3a96d664f0..6709f48066 100644
--- a/ldm/invoke/prompt_parser.py
+++ b/ldm/invoke/prompt_parser.py
@@ -51,6 +51,11 @@ class FlattenedPrompt():
                 raise PromptParser.ParsingException(
                     f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
 
+    @property
+    def is_empty(self):
+        return len(self.children) == 0 or \
+               (len(self.children) == 1 and len(self.children[0].text) == 0)
+
     def __repr__(self):
         return f"FlattenedPrompt:{self.children}"
     def __eq__(self, other):
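
The token logging re-enabled by this patch prints the prompt fragments plus the token ids, trimmed at the first end-of-sequence token so that padding is not dumped to the console. Below is a minimal sketch of just that trimming step, using invented token ids and the same placeholder labels as the code above.

```python
# Invented ids for illustration; the real values come from the CLIP tokenizer
# (tokenizer.bos_token_id / tokenizer.eos_token_id).
start_token, end_token = 49406, 49407
tokens_list = [49406, 320, 2368, 49407, 49407, 49407]  # roughly: bos, 'a', 'cat', eos, padding

if tokens_list[0] == start_token:
    tokens_list[0] = '<start>'
try:
    first_end_token_index = tokens_list.index(end_token)
    tokens_list[first_end_token_index] = '<end>'
    tokens_list = tokens_list[:first_end_token_index + 1]
except ValueError:
    pass  # no eos found; leave the list as-is

print(tokens_list)  # ['<start>', 320, 2368, '<end>'] -- padding is not logged
```
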
From d12ae3bab0719c1e9f5aa925768407012538a97c Mon Sep 17 00:00:00 2001
From: Damian at mba
Date: Mon, 24 Oct 2022 14:58:38 +0200
Subject: [PATCH 54/54] documentation for new prompt syntax

---
 docs/features/PROMPTS.md | 42 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/docs/features/PROMPTS.md b/docs/features/PROMPTS.md
index b5ef26858b..8fdb97b7b8 100644
--- a/docs/features/PROMPTS.md
+++ b/docs/features/PROMPTS.md
@@ -84,6 +84,48 @@ Getting close - but there's no sense in having a saddle when our horse doesn't h
 
 ---
 
+## **Prompt Syntax Features**
+
+The InvokeAI prompting language has the following features:
+
+### Attention weighting
+Append a word or phrase with `-` or `+`, or a weight between `0` and `2` (`1`=default), to decrease or increase "attention" (= a mix of per-token CFG weighting multiplier and, for `-`, a weighted blend with the prompt without the term).
+
+The following will be recognised:
+  * single words without parentheses: `a tall thin man picking apricots+`
+  * single or multiple words with parentheses: `a tall thin man picking (apricots)+` `a tall thin man picking (apricots)-` `a tall thin man (picking apricots)+` `a tall thin man (picking apricots)-`
+  * more effect with more symbols: `a tall thin man (picking apricots)++`
+  * nesting: `a tall thin man (picking apricots+)++` (`apricots` effectively gets `+++`)
+  * all of the above with explicit numbers: `a tall thin man picking (apricots)1.1` `a tall thin man (picking (apricots)1.3)1.1`. (`+` is equivalent to 1.1, `++` is pow(1.1,2), `+++` is pow(1.1,3), etc.; `-` means 0.9, `--` means pow(0.9,2), etc. - a small worked example follows this list)
+  * attention also applies to `[unconditioning]`, so `a tall thin man picking apricots [(ladder)0.01]` will *very gently* nudge SD away from trying to draw the man on a ladder
+
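The worked example referenced in the list above: a tiny Python illustration of how `+`/`-` counts and explicit numbers map to an attention multiplier. This is not the parser's actual code, just the arithmetic the documentation describes.

```python
# Not the parser -- just the attention arithmetic described above.
def attention_weight(symbols='', explicit=None):
    """'+' multiplies by 1.1 per symbol, '-' by 0.9 per symbol;
    an explicit number (0..2, default 1) overrides the symbols."""
    if explicit is not None:
        return explicit
    return (1.1 ** symbols.count('+')) * (0.9 ** symbols.count('-'))

assert abs(attention_weight('+') - 1.1) < 1e-9     # 'apricots+'
assert abs(attention_weight('++') - 1.21) < 1e-9   # pow(1.1, 2)
assert abs(attention_weight('-') - 0.9) < 1e-9     # 'apricots-'
assert attention_weight(explicit=1.1) == 1.1       # '(apricots)1.1'
```
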
+### Blending between prompts
+
+* `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,1)`
+* The existing prompt blending using `:<weight>` will continue to be supported - `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,1)` is equivalent to `a tall thin man picking apricots:1 a tall thin man picking pears:1` in the old syntax.
+* Attention weights can be nested inside blends.
+* Non-normalized blends are supported by passing `no_normalize` as an additional argument to the blend weights, e.g. `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,-1,no_normalize)`. Very fun to explore local maxima in the feature space, but also easy to produce garbage output.
+
+See the section below on "Prompt Blending" for more information about how this works.
+
+### Cross-Attention Control ('prompt2prompt')
+
+Denoise with a given prompt and then re-use the attention→pixel maps to substitute words in the original prompt for words in a new prompt. Based on [bloc97's colab](https://github.com/bloc97/CrossAttentionControl).
+
+* `a ("fluffy cat").swap("smiling dog") eating a hotdog`.
+  * quotes optional: `a (fluffy cat).swap(smiling dog) eating a hotdog`.
+  * for single word substitutions parentheses are also optional: `a cat.swap(dog) eating a hotdog`.
+* Supports options `s_start`, `s_end`, `t_start`, `t_end` (each 0-1) loosely corresponding to bloc97's `prompt_edit_spatial_start/_end` and `prompt_edit_tokens_start/_end` but with the math swapped to make it easier to intuitively understand.
+  * Example usage: `a (cat).swap(dog, s_end=0.3) eating a hotdog` - the `s_end` argument means that the "spatial" (self-attention) edit will stop having any effect after 30% (=0.3) of the steps have been done, leaving Stable Diffusion with 70% of the steps where it is free to decide for itself how to reshape the cat-form into a dog form.
+  * The numbers represent a percentage through the step sequence where the edits should happen. 0 means the start (noisy starting image), 1 is the end (final image).
+  * For img2img, the step sequence does not start at 0 but instead at (1-strength) - so if strength is 0.7, s_start and s_end must both be greater than 0.3 (1-0.7) to have any effect.
+* Convenience option `shape_freedom` (0-1) to specify how much "freedom" Stable Diffusion should have to change the shape of the subject being swapped.
+  * `a (cat).swap(dog, shape_freedom=0.5) eating a hotdog`.
+
+### Escaping parentheses () and speech marks ""
+
+If the model you are using has parentheses () or speech marks "" as part of its syntax, you will need to "escape" these using a backslash, so that `(my_keyword)` becomes `\(my_keyword\)`. Otherwise, the prompt parser will attempt to interpret the parentheses as part of the prompt syntax and it will get confused.
+
 ## **Prompt Blending**
 
 You may blend together different sections of the prompt to explore the