diff --git a/backend/invoke_ai_web_server.py b/backend/invoke_ai_web_server.py index 96ecda1af1..dabe072f80 100644 --- a/backend/invoke_ai_web_server.py +++ b/backend/invoke_ai_web_server.py @@ -14,7 +14,7 @@ from threading import Event from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash from ldm.invoke.pngwriter import PngWriter, retrieve_metadata -from ldm.invoke.conditioning import split_weighted_subprompts +from ldm.invoke.prompt_parser import split_weighted_subprompts from backend.modules.parameters import parameters_to_command diff --git a/backend/server.py b/backend/server.py index 7b8a8a5a69..8ad861356c 100644 --- a/backend/server.py +++ b/backend/server.py @@ -33,7 +33,7 @@ from ldm.generate import Generate from ldm.invoke.restoration import Restoration from ldm.invoke.pngwriter import PngWriter, retrieve_metadata from ldm.invoke.args import APP_ID, APP_VERSION, calculate_init_img_hash -from ldm.invoke.conditioning import split_weighted_subprompts +from ldm.invoke.prompt_parser import split_weighted_subprompts from modules.parameters import parameters_to_command diff --git a/configs/stable-diffusion/v1-inference.yaml b/configs/stable-diffusion/v1-inference.yaml index 9c773077b6..baf91f6e26 100644 --- a/configs/stable-diffusion/v1-inference.yaml +++ b/configs/stable-diffusion/v1-inference.yaml @@ -76,4 +76,4 @@ model: target: torch.nn.Identity cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + target: ldm.modules.encoders.modules.WeightedFrozenCLIPEmbedder diff --git a/docs/features/PROMPTS.md b/docs/features/PROMPTS.md index b5ef26858b..8fdb97b7b8 100644 --- a/docs/features/PROMPTS.md +++ b/docs/features/PROMPTS.md @@ -84,6 +84,48 @@ Getting close - but there's no sense in having a saddle when our horse doesn't h --- +## **Prompt Syntax Features** + +The InvokeAI prompting language has the following features: + +### Attention weighting +Append a word or phrase with `-` or `+`, or a weight between `0` and `2` (`1`=default), to decrease or increase "attention" (= a mix of per-token CFG weighting multiplier and, for `-`, a weighted blend with the prompt without the term). + +The following will be recognised: + * single words without parentheses: `a tall thin man picking apricots+` + * single or multiple words with parentheses: `a tall thin man picking (apricots)+` `a tall thin man picking (apricots)-` `a tall thin man (picking apricots)+` `a tall thin man (picking apricots)-` + * more effect with more symbols `a tall thin man (picking apricots)++` + * nesting `a tall thin man (picking apricots+)++` (`apricots` effectively gets `+++`) + * all of the above with explicit numbers `a tall thin man picking (apricots)1.1` `a tall thin man (picking (apricots)1.3)1.1`. (`+` is equivalent to 1.1, `++` is pow(1.1,2), `+++` is pow(1.1,3), etc; `-` means 0.9, `--` means pow(0.9,2), etc.) + * attention also applies to `[unconditioning]` so `a tall thin man picking apricots [(ladder)0.01]` will *very gently* nudge SD away from trying to draw the man on a ladder + +### Blending between prompts + +* `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,1)` +* The existing prompt blending using `:` will continue to be supported - `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,1)` is equivalent to `a tall thin man picking apricots:1 a tall thin man picking pears:1` in the old syntax. +* Attention weights can be nested inside blends. 
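+  * for example, `("a tall thin man picking (apricots)1.3", "a tall thin man picking pears").blend(1,1)` applies the `1.3` attention weight to `apricots` inside the first prompt of the blend.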
+* Non-normalized blends are supported by passing `no_normalize` as an additional argument to the blend weights, eg `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,-1,no_normalize)`. Very fun to explore local maxima in the feature space, but also easy to produce garbage output. + +See the section below on "Prompt Blending" for more information about how this works. + +### Cross-Attention Control ('prompt2prompt') + +Denoise with a given prompt and then re-use the attention→pixel maps to substitute words in the original prompt for words in a new prompt. Based off [bloc97's colab](https://github.com/bloc97/CrossAttentionControl). + +* `a ("fluffy cat").swap("smiling dog") eating a hotdog`. + * quotes optional: `a (fluffy cat).swap(smiling dog) eating a hotdog`. + * for single word substitutions parentheses are also optional: `a cat.swap(dog) eating a hotdog`. +* Supports options `s_start`, `s_end`, `t_start`, `t_end` (each 0-1) loosely corresponding to bloc97's `prompt_edit_spatial_start/_end` and `prompt_edit_tokens_start/_end` but with the math swapped to make it easier to intuitively understand. + * Example usage: `a (cat).swap(dog, s_end=0.3) eating a hotdog` - the `s_end` argument means that the "spatial" (self-attention) edit will stop having any effect after 30% (=0.3) of the steps have been done, leaving Stable Diffusion with 70% of the steps where it is free to decide for itself how to reshape the cat-form into a dog form. + * The numbers represent a percentage through the step sequence where the edits should happen. 0 means the start (noisy starting image), 1 is the end (final image). + * For img2img, the step sequence does not start at 0 but instead at (1-strength) - so if strength is 0.7, s_start and s_end must both be greater than 0.3 (1-0.7) to have any effect. +* Convenience option `shape_freedom` (0-1) to specify how much "freedom" Stable Diffusion should have to change the shape of the subject being swapped. + * `a (cat).swap(dog, shape_freedom=0.5) eating a hotdog`. + +### Escaping parentheses () and speech marks "" + +If the model you are using has parentheses () or speech marks "" as part of its syntax, you will need to "escape" these using a backslash, so that `(my_keyword)` becomes `\(my_keyword\)`. Otherwise, the prompt parser will attempt to interpret the parentheses as part of the prompt syntax and it will get confused. + ## **Prompt Blending** You may blend together different sections of the prompt to explore the diff --git a/ldm/generate.py b/ldm/generate.py index 3ede1710e1..43ed28eecd 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -1,5 +1,5 @@ # Copyright (c) 2022 Lincoln D.
Stein (https://github.com/lstein) - +import pyparsing # Derived from source code carrying the following copyrights # Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors @@ -24,6 +24,7 @@ from PIL import Image, ImageOps from torch import nn from pytorch_lightning import seed_everything, logging +from ldm.invoke.prompt_parser import PromptParser from ldm.util import instantiate_from_config from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.plms import PLMSSampler @@ -32,7 +33,7 @@ from ldm.invoke.pngwriter import PngWriter from ldm.invoke.args import metadata_from_png from ldm.invoke.image_util import InitImageResizer from ldm.invoke.devices import choose_torch_device, choose_precision -from ldm.invoke.conditioning import get_uc_and_c +from ldm.invoke.conditioning import get_uc_and_c_and_ec from ldm.invoke.model_cache import ModelCache from ldm.invoke.seamless import configure_model_padding from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale @@ -404,7 +405,7 @@ class Generate: mask_image = None try: - uc, c = get_uc_and_c( + uc, c, extra_conditioning_info = get_uc_and_c_and_ec( prompt, model =self.model, skip_normalize=skip_normalize, log_tokens =self.log_tokenization @@ -448,7 +449,7 @@ class Generate: sampler=self.sampler, steps=steps, cfg_scale=cfg_scale, - conditioning=(uc, c), + conditioning=(uc, c, extra_conditioning_info), ddim_eta=ddim_eta, image_callback=image_callback, # called after the final image is generated step_callback=step_callback, # called after each intermediate image is generated @@ -481,14 +482,14 @@ class Generate: save_original = save_original, image_callback = image_callback) - except RuntimeError as e: - print(traceback.format_exc(), file=sys.stderr) - print('>> Could not generate image.') except KeyboardInterrupt: if catch_interrupts: print('**Interrupted** Partial results will be returned.') else: raise KeyboardInterrupt + except RuntimeError as e: + print(traceback.format_exc(), file=sys.stderr) + print('>> Could not generate image.') toc = time.time() print('>> Usage stats:') @@ -553,7 +554,8 @@ class Generate: image = Image.open(image_path) # used by multiple postfixers - uc, c = get_uc_and_c( + # todo: cross-attention control + uc, c, _ = get_uc_and_c_and_ec( prompt, model =self.model, skip_normalize=opt.skip_normalize, log_tokens =opt.log_tokenization diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py index b24620df15..abde269acf 100644 --- a/ldm/invoke/args.py +++ b/ldm/invoke/args.py @@ -92,7 +92,7 @@ import copy import base64 import functools import ldm.invoke.pngwriter -from ldm.invoke.conditioning import split_weighted_subprompts +from ldm.invoke.prompt_parser import split_weighted_subprompts SAMPLER_CHOICES = [ 'ddim', diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index fedd965a2c..7c095de7b7 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -4,107 +4,166 @@ weighted subprompts. 
Useful function exports: -get_uc_and_c() get the conditioned and unconditioned latent +get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control split_weighted_subpromopts() split subprompts, normalize and weight them log_tokenization() print out colour-coded tokens and warn if truncated ''' import re +from difflib import SequenceMatcher +from typing import Union + import torch -def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False): +from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \ + CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment +from ..models.diffusion.cross_attention_control import CrossAttentionControl +from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent +from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder + + +def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False): + # Extract Unconditioned Words From Prompt unconditioned_words = '' unconditional_regex = r'\[(.*?)\]' - unconditionals = re.findall(unconditional_regex, prompt) + unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned) if len(unconditionals) > 0: unconditioned_words = ' '.join(unconditionals) # Remove Unconditioned Words From Prompt unconditional_regex_compile = re.compile(unconditional_regex) - clean_prompt = unconditional_regex_compile.sub(' ', prompt) - prompt = re.sub(' +', ' ', clean_prompt) + clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned) + prompt_string_cleaned = re.sub(' +', ' ', clean_prompt) + else: + prompt_string_cleaned = prompt_string_uncleaned - uc = model.get_learned_conditioning([unconditioned_words]) + pp = PromptParser() - # get weighted sub-prompts - weighted_subprompts = split_weighted_subprompts( - prompt, skip_normalize + parsed_prompt: Union[FlattenedPrompt, Blend] = None + legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned) + if legacy_blend is not None: + parsed_prompt = legacy_blend + else: + # we don't support conjunctions for now + parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0] + + parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0] + print(f">> Parsed prompt to {parsed_prompt}") + + conditioning = None + cac_args:CrossAttentionControl.Arguments = None + + if type(parsed_prompt) is Blend: + blend: Blend = parsed_prompt + embeddings_to_blend = None + for flattened_prompt in blend.prompts: + this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens) + embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat( + (embeddings_to_blend, this_embedding)) + conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0), + blend.weights, + normalize=blend.normalize_weights) + else: + flattened_prompt: FlattenedPrompt = parsed_prompt + wants_cross_attention_control = type(flattened_prompt) is not Blend \ + and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children]) + if wants_cross_attention_control: + original_prompt = FlattenedPrompt() + edited_prompt = FlattenedPrompt() + # for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed + original_token_count = 0 + edited_token_count = 0 + edit_opcodes = [] + edit_options = [] + for fragment in 
flattened_prompt.children: + if type(fragment) is CrossAttentionControlSubstitute: + original_prompt.append(fragment.original) + edited_prompt.append(fragment.edited) + + to_replace_token_count = get_tokens_length(model, fragment.original) + replacement_token_count = get_tokens_length(model, fragment.edited) + edit_opcodes.append(('replace', + original_token_count, original_token_count + to_replace_token_count, + edited_token_count, edited_token_count + replacement_token_count + )) + original_token_count += to_replace_token_count + edited_token_count += replacement_token_count + edit_options.append(fragment.options) + #elif type(fragment) is CrossAttentionControlAppend: + # edited_prompt.append(fragment.fragment) + else: + # regular fragment + original_prompt.append(fragment) + edited_prompt.append(fragment) + + count = get_tokens_length(model, [fragment]) + edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count)) + edit_options.append(None) + original_token_count += count + edited_token_count += count + original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt, log_tokens=log_tokens) + # naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of + # subsequent tokens when there is >1 edit and earlier edits change the total token count. + # eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the + # 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra + # token 'smiling' in the inactive 'cat' edit. + # todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions + edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt, log_tokens=log_tokens) + + conditioning = original_embeddings + edited_conditioning = edited_embeddings + print('got edit_opcodes', edit_opcodes, 'options', edit_options) + cac_args = CrossAttentionControl.Arguments( + edited_conditioning = edited_conditioning, + edit_opcodes = edit_opcodes, + edit_options = edit_options + ) + else: + conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens) + + + unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens) + return ( + unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo( + cross_attention_control_args=cac_args + ) ) - if len(weighted_subprompts) > 1: - # i dont know if this is correct.. 
but it works - c = torch.zeros_like(uc) - # normalize each "sub prompt" and add it - for subprompt, weight in weighted_subprompts: - log_tokenization(subprompt, model, log_tokens, weight) - c = torch.add( - c, - model.get_learned_conditioning([subprompt]), - alpha=weight, - ) - else: # just standard 1 prompt - log_tokenization(prompt, model, log_tokens, 1) - c = model.get_learned_conditioning([prompt]) - uc = model.get_learned_conditioning([unconditioned_words]) - return (uc, c) -def split_weighted_subprompts(text, skip_normalize=False)->list: - """ - grabs all text up to the first occurrence of ':' - uses the grabbed text as a sub-prompt, and takes the value following ':' as weight - if ':' has no value defined, defaults to 1.0 - repeats until no text remaining - """ - prompt_parser = re.compile(""" - (?P # capture group for 'prompt' - (?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:' - ) # end 'prompt' - (?: # non-capture group - :+ # match one or more ':' characters - (?P # capture group for 'weight' - -?\d+(?:\.\d+)? # match positive or negative integer or decimal number - )? # end weight capture group, make optional - \s* # strip spaces after weight - | # OR - $ # else, if no ':' then match end of line - ) # end non-capture group - """, re.VERBOSE) - parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float( - match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)] - if skip_normalize: - return parsed_prompts - weight_sum = sum(map(lambda x: x[1], parsed_prompts)) - if weight_sum == 0: - print( - "Warning: Subprompt weights add up to zero. Discarding and using even weights instead.") - equal_weight = 1 / max(len(parsed_prompts), 1) - return [(x[0], equal_weight) for x in parsed_prompts] - return [(x[0], x[1] / weight_sum) for x in parsed_prompts] +def build_token_edit_opcodes(original_tokens, edited_tokens): + original_tokens = original_tokens.cpu().numpy()[0] + edited_tokens = edited_tokens.cpu().numpy()[0] + + return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes() + +def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False): + if type(flattened_prompt) is not FlattenedPrompt: + raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead") + fragments = [x.text for x in flattened_prompt.children] + weights = [x.weight for x in flattened_prompt.children] + embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights]) + if not flattened_prompt.is_empty and log_tokens: + start_token = model.cond_stage_model.tokenizer.bos_token_id + end_token = model.cond_stage_model.tokenizer.eos_token_id + tokens_list = tokens[0].tolist() + if tokens_list[0] == start_token: + tokens_list[0] = '' + try: + first_end_token_index = tokens_list.index(end_token) + tokens_list[first_end_token_index] = '' + tokens_list = tokens_list[:first_end_token_index+1] + except ValueError: + pass + + print(f">> Prompt fragments {fragments}, tokenized to \n{tokens_list}") + + return embeddings, tokens + +def get_tokens_length(model, fragments: list[Fragment]): + fragment_texts = [x.text for x in fragments] + tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False) + return sum([len(x) for x in tokens]) + -# shows how the prompt is tokenized -# usually tokens have '' to indicate end-of-word, -# but for readability it has been replaced with ' ' -def 
log_tokenization(text, model, log=False, weight=1): - if not log: - return - tokens = model.cond_stage_model.tokenizer._tokenize(text) - tokenized = "" - discarded = "" - usedTokens = 0 - totalTokens = len(tokens) - for i in range(0, totalTokens): - token = tokens[i].replace('', ' ') - # alternate color - s = (usedTokens % 6) + 1 - if i < model.cond_stage_model.max_length: - tokenized = tokenized + f"\x1b[0;3{s};40m{token}" - usedTokens += 1 - else: # over max token length - discarded = discarded + f"\x1b[0;3{s};40m{token}" - print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m") - if discarded != "": - print( - f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m" - ) diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py index 613f1aca31..79b943024c 100644 --- a/ldm/invoke/generator/img2img.py +++ b/ldm/invoke/generator/img2img.py @@ -10,6 +10,7 @@ from PIL import Image from ldm.invoke.devices import choose_autocast from ldm.invoke.generator.base import Generator from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent class Img2Img(Generator): def __init__(self, model, precision): @@ -38,7 +39,7 @@ class Img2Img(Generator): ) # move to latent space t_enc = int(strength * steps) - uc, c = conditioning + uc, c, extra_conditioning_info = conditioning def make_image(x_T): # encode (scaled latent) @@ -55,7 +56,9 @@ class Img2Img(Generator): img_callback = step_callback, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, - init_latent = self.init_latent, # changes how noising is performed in ksampler + init_latent = self.init_latent, # changes how noising is performed in ksampler + extra_conditioning_info = extra_conditioning_info, + all_timesteps_count = steps ) return self.sample_to_image(samples) diff --git a/ldm/invoke/generator/inpaint.py b/ldm/invoke/generator/inpaint.py index 8fbcf249aa..34d4f209fc 100644 --- a/ldm/invoke/generator/inpaint.py +++ b/ldm/invoke/generator/inpaint.py @@ -73,7 +73,8 @@ class Inpaint(Img2Img): ) # move to latent space t_enc = int(strength * steps) - uc, c = conditioning + # todo: support cross-attention control + uc, c, _ = conditioning print(f">> target t_enc is {t_enc} steps") diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py index 86eb679c04..ba49d2ef55 100644 --- a/ldm/invoke/generator/txt2img.py +++ b/ldm/invoke/generator/txt2img.py @@ -5,6 +5,8 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator import torch import numpy as np from ldm.invoke.generator.base import Generator +from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent + class Txt2Img(Generator): def __init__(self, model, precision): @@ -19,7 +21,7 @@ class Txt2Img(Generator): kwargs are 'width' and 'height' """ self.perlin = perlin - uc, c = conditioning + uc, c, extra_conditioning_info = conditioning @torch.no_grad() def make_image(x_T): @@ -43,6 +45,7 @@ class Txt2Img(Generator): verbose = False, unconditional_guidance_scale = cfg_scale, unconditional_conditioning = uc, + extra_conditioning_info = extra_conditioning_info, eta = ddim_eta, img_callback = step_callback, threshold = threshold, diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py index 4e68d94875..9fad2d80e1 100644 --- a/ldm/invoke/generator/txt2img2img.py +++ b/ldm/invoke/generator/txt2img2img.py @@ -7,6 +7,7 @@ import numpy as np import math from 
ldm.invoke.generator.base import Generator from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent class Txt2Img2Img(Generator): @@ -22,7 +23,7 @@ class Txt2Img2Img(Generator): Return value depends on the seed at the time you call it kwargs are 'width' and 'height' """ - uc, c = conditioning + uc, c, extra_conditioning_info = conditioning @torch.no_grad() def make_image(x_T): @@ -60,7 +61,8 @@ class Txt2Img2Img(Generator): unconditional_guidance_scale = cfg_scale, unconditional_conditioning = uc, eta = ddim_eta, - img_callback = step_callback + img_callback = step_callback, + extra_conditioning_info = extra_conditioning_info ) print( @@ -94,6 +96,8 @@ class Txt2Img2Img(Generator): img_callback = step_callback, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, + extra_conditioning_info=extra_conditioning_info, + all_timesteps_count=steps ) if self.free_gpu_mem: diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py new file mode 100644 index 0000000000..6709f48066 --- /dev/null +++ b/ldm/invoke/prompt_parser.py @@ -0,0 +1,680 @@ +import string +from typing import Union, Optional +import re +import pyparsing as pp + +class Prompt(): + """ + Mid-level structure for storing the tree-like result of parsing a prompt. A Prompt may not represent the whole of + the singular user-defined "prompt string" (although it can) - for example, if the user specifies a Blend, the objects + that are to be blended together are stored individuall as Prompt objects. + + Nesting makes this object not suitable for directly tokenizing; instead call flatten() on the containing Conjunction + to produce a FlattenedPrompt. + """ + def __init__(self, parts: list): + for c in parts: + if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults: + raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} {c}, only {BaseFragment.__subclasses__()} are allowed") + self.children = parts + def __repr__(self): + return f"Prompt:{self.children}" + def __eq__(self, other): + return type(other) is Prompt and other.children == self.children + +class BaseFragment: + pass + +class FlattenedPrompt(): + """ + A Prompt that has been passed through flatten(). Its children can be readily tokenized. 
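+ For example, parsing and flattening `a tall thin man picking (apricots)1.1` produces children along the lines of [Fragment('a tall thin man picking', 1.0), Fragment('apricots', 1.1)] (illustrative).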
+ """ + def __init__(self, parts: list=[]): + self.children = [] + for part in parts: + self.append(part) + + def append(self, fragment: Union[list, BaseFragment, tuple]): + # verify type correctness + if type(fragment) is list: + for x in fragment: + self.append(x) + elif issubclass(type(fragment), BaseFragment): + self.children.append(fragment) + elif type(fragment) is tuple: + # upgrade tuples to Fragments + if type(fragment[0]) is not str or (type(fragment[1]) is not float and type(fragment[1]) is not int): + raise PromptParser.ParsingException( + f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed") + self.children.append(Fragment(fragment[0], fragment[1])) + else: + raise PromptParser.ParsingException( + f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed") + + @property + def is_empty(self): + return len(self.children) == 0 or \ + (len(self.children) == 1 and len(self.children[0].text) == 0) + + def __repr__(self): + return f"FlattenedPrompt:{self.children}" + def __eq__(self, other): + return type(other) is FlattenedPrompt and other.children == self.children + + +class Fragment(BaseFragment): + """ + A Fragment is a chunk of plain text and an optional weight. The text should be passed as-is to the CLIP tokenizer. + """ + def __init__(self, text: str, weight: float=1): + assert(type(text) is str) + if '\\"' in text or '\\(' in text or '\\)' in text: + #print("Fragment converting escaped \( \) \\\" into ( ) \"") + text = text.replace('\\(', '(').replace('\\)', ')').replace('\\"', '"') + self.text = text + self.weight = float(weight) + + def __repr__(self): + return "Fragment:'"+self.text+"'@"+str(self.weight) + def __eq__(self, other): + return type(other) is Fragment \ + and other.text == self.text \ + and other.weight == self.weight + +class Attention(): + """ + Nestable weight control for fragments. Each object in the children array may in turn be an Attention object; + weights should be considered to accumulate as the tree is traversed to deeper levels of nesting. + + Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object. + """ + def __init__(self, weight: float, children: list): + self.weight = weight + self.children = children + #print(f"A: requested attention '{children}' to {weight}") + + def __repr__(self): + return f"Attention:'{self.children}' @ {self.weight}" + def __eq__(self, other): + return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment + +class CrossAttentionControlledFragment(BaseFragment): + pass + +class CrossAttentionControlSubstitute(CrossAttentionControlledFragment): + """ + A Cross-Attention Controlled ('prompt2prompt') fragment, for use inside a Prompt, Attention, or FlattenedPrompt. + Representing an "original" word sequence that supplies feature vectors for an initial diffusion operation, and an + "edited" word sequence, to which the attention maps produced by the "original" word sequence are applied. Intuitively, + the result should be an "edited" image that looks like the "original" image with concepts swapped. + + eg "a cat sitting on a car" (original) -> "a smiling dog sitting on a car" (edited): the edited image should look + almost exactly the same as the original, but with a smiling dog rendered in place of the cat. 
The + CrossAttentionControlSubstitute object representing this swap may be confined to the tokens being swapped: + CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]) + or it may represent a larger portion of the token sequence: + CrossAttentionControlSubstitute(original=[Fragment('a cat sitting on a car')], + edited=[Fragment('a smiling dog sitting on a car')]) + + In either case expect it to be embedded in a Prompt or FlattenedPrompt: + FlattenedPrompt([ + Fragment('a'), + CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]), + Fragment('sitting on a car') + ]) + """ + def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list], options: dict=None): + self.original = original + self.edited = edited + + default_options = { + 's_start': 0.0, + 's_end': 0.206, # ~= shape_freedom=0.5 + 't_start': 0.0, + 't_end': 1.0 + } + merged_options = default_options + if options is not None: + shape_freedom = options.pop('shape_freedom', None) + if shape_freedom is not None: + # high shape freedom = SD can do what it wants with the shape of the object + # high shape freedom => s_end = 0 + # low shape freedom => s_end = 1 + # shape freedom is in a "linear" space, while noticeable changes to s_end are typically closer around 0, + # and there is very little perceptible difference as s_end increases above 0.5 + # so for shape_freedom = 0.5 we probably want s_end to be 0.2 + # -> cube root and subtract from 1.0 + merged_options['s_end'] = 1.0 - shape_freedom ** (1. / 3.) + print('converted shape_freedom argument to', merged_options) + merged_options.update(options) + + self.options = merged_options + + def __repr__(self): + return f"CrossAttentionControlSubstitute:({self.original}->{self.edited} ({self.options})" + def __eq__(self, other): + return type(other) is CrossAttentionControlSubstitute \ + and other.original == self.original \ + and other.edited == self.edited \ + and other.options == self.options + + +class CrossAttentionControlAppend(CrossAttentionControlledFragment): + def __init__(self, fragment: Fragment): + self.fragment = fragment + def __repr__(self): + return "CrossAttentionControlAppend:",self.fragment + def __eq__(self, other): + return type(other) is CrossAttentionControlAppend \ + and other.fragment == self.fragment + + + +class Conjunction(): + """ + Storage for one or more Prompts or Blends, each of which is to be separately diffused and then the results merged + by weighted sum in latent space. + """ + def __init__(self, prompts: list, weights: list = None): + # force everything to be a Prompt + #print("making conjunction with", parts) + self.prompts = [x if (type(x) is Prompt + or type(x) is Blend + or type(x) is FlattenedPrompt) + else Prompt(x) for x in prompts] + self.weights = [1.0]*len(self.prompts) if weights is None else list(weights) + if len(self.weights) != len(self.prompts): + raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}") + self.type = 'AND' + + def __repr__(self): + return f"Conjunction:{self.prompts} | weights {self.weights}" + def __eq__(self, other): + return type(other) is Conjunction \ + and other.prompts == self.prompts \ + and other.weights == self.weights + + +class Blend(): + """ + Stores a Blend of multiple Prompts. 
To apply, build feature vectors for each of the child Prompts and then perform a + weighted blend of the feature vectors to produce a single feature vector that is effectively a lerp between the + Prompts. + """ + def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True): + #print("making Blend with prompts", prompts, "and weights", weights) + if len(prompts) != len(weights): + raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}") + for c in prompts: + if type(c) is not Prompt and type(c) is not FlattenedPrompt: + raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts")) + # upcast all lists to Prompt objects + self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt) + else Prompt(x) for x in prompts] + self.prompts = prompts + self.weights = weights + self.normalize_weights = normalize_weights + + def __repr__(self): + return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}" + def __eq__(self, other): + return other.__repr__() == self.__repr__() + + +class PromptParser(): + + class ParsingException(Exception): + pass + + def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9): + + self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base) + + + def parse_conjunction(self, prompt: str) -> Conjunction: + ''' + :param prompt: The prompt string to parse + :return: a Conjunction representing the parsed results. + ''' + #print(f"!!parsing '{prompt}'") + + if len(prompt.strip()) == 0: + return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0]) + + root = self.conjunction.parse_string(prompt) + #print(f"'{prompt}' parsed to root", root) + #fused = fuse_fragments(parts) + #print("fused to", fused) + + return self.flatten(root[0]) + + def parse_legacy_blend(self, text: str) -> Optional[Blend]: + weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False) + if len(weighted_subprompts) == 1: + return None + strings = [x[0] for x in weighted_subprompts] + weights = [x[1] for x in weighted_subprompts] + + parsed_conjunctions = [self.parse_conjunction(x) for x in strings] + flattened_prompts = [x.prompts[0] for x in parsed_conjunctions] + + return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True) + + + def flatten(self, root: Conjunction) -> Conjunction: + """ + Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends, + producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects + that can be readily tokenized without the need to walk a complex tree structure. + + :param root: The Conjunction to flatten. + :return: A Conjunction containing the result of flattening each of the prompts in the passed-in root. 
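+ For example (illustrative), a Conjunction wrapping Prompt([Fragment('a man picking'), Attention(1.1, [Fragment('apricots')])]) flattens to a Conjunction wrapping FlattenedPrompt([Fragment('a man picking')@1.0, Fragment('apricots')@1.1]), with nested Attention weights folded into the Fragment weights.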
+ """ + + #print("flattening", root) + + def fuse_fragments(items): + # print("fusing fragments in ", items) + result = [] + for x in items: + if type(x) is CrossAttentionControlSubstitute: + original_fused = fuse_fragments(x.original) + edited_fused = fuse_fragments(x.edited) + result.append(CrossAttentionControlSubstitute(original_fused, edited_fused, options=x.options)) + else: + last_weight = result[-1].weight \ + if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \ + else None + this_text = x.text + this_weight = x.weight + if last_weight is not None and last_weight == this_weight: + last_text = result[-1].text + result[-1] = Fragment(last_text + ' ' + this_text, last_weight) + else: + result.append(x) + return result + + def flatten_internal(node, weight_scale, results, prefix): + #print(prefix + "flattening", node, "...") + if type(node) is pp.ParseResults: + for x in node: + results = flatten_internal(x, weight_scale, results, prefix+' pr ') + #print(prefix, " ParseResults expanded, results is now", results) + elif type(node) is Attention: + # if node.weight < 1: + # todo: inject a blend when flattening attention with weight <1" + for index,c in enumerate(node.children): + results = flatten_internal(c, weight_scale * node.weight, results, prefix + f" att{index} ") + elif type(node) is Fragment: + results += [Fragment(node.text, node.weight*weight_scale)] + elif type(node) is CrossAttentionControlSubstitute: + original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ') + edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ') + results += [CrossAttentionControlSubstitute(original, edited, options=node.options)] + elif type(node) is Blend: + flattened_subprompts = [] + #print(" flattening blend with prompts", node.prompts, "weights", node.weights) + for prompt in node.prompts: + # prompt is a list + flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ') + results += [Blend(prompts=flattened_subprompts, weights=node.weights, normalize_weights=node.normalize_weights)] + elif type(node) is Prompt: + #print(prefix + "about to flatten Prompt with children", node.children) + flattened_prompt = [] + for child in node.children: + flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ') + results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))] + #print(prefix + "after flattening Prompt, results is", results) + else: + raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}") + #print(prefix + "-> after flattening", type(node).__name__, "results is", results) + return results + + + flattened_parts = [] + for part in root.prompts: + flattened_parts += flatten_internal(part, 1.0, [], ' C| ') + + #print("flattened to", flattened_parts) + + weights = root.weights + return Conjunction(flattened_parts, weights) + + + +def build_parser_syntax(attention_plus_base: float, attention_minus_base: float): + + lparen = pp.Literal("(").suppress() + rparen = pp.Literal(")").suppress() + quotes = pp.Literal('"').suppress() + comma = pp.Literal(",").suppress() + + # accepts int or float notation, always maps to float + number = pp.pyparsing_common.real | \ + pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float)) + greedy_word = pp.Word(pp.printables, exclude_chars=string.whitespace).set_name('greedy_word') + + attention = pp.Forward() + quoted_fragment = pp.Forward() + 
parenthesized_fragment = pp.Forward() + cross_attention_substitute = pp.Forward() + prompt_part = pp.Forward() + + def make_text_fragment(x): + #print("### making fragment for", x) + if type(x) is str: + return Fragment(x) + elif type(x) is pp.ParseResults or type(x) is list: + #print(f'converting {type(x).__name__} to Fragment') + return Fragment(' '.join([s for s in x])) + else: + raise PromptParser.ParsingException("Cannot make fragment from " + str(x)) + + def build_escaped_word_parser(escaped_chars_to_ignore: str): + terms = [] + for c in escaped_chars_to_ignore: + terms.append(pp.Literal('\\'+c)) + terms.append( + #pp.CharsNotIn(string.whitespace + escaped_chars_to_ignore, exact=1) + pp.Word(pp.printables, exclude_chars=string.whitespace + escaped_chars_to_ignore) + ) + return pp.Combine(pp.OneOrMore( + pp.MatchFirst(terms) + )) + + def build_escaped_word_parser_charbychar(escaped_chars_to_ignore: str): + escapes = [] + for c in escaped_chars_to_ignore: + escapes.append(pp.Literal('\\'+c)) + return pp.Combine(pp.OneOrMore( + pp.MatchFirst(escapes + [pp.CharsNotIn( + string.whitespace + escaped_chars_to_ignore, + exact=1 + )]) + )) + + + + def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False): + #print(f"parsing fragment string \"{x}\"") + fragment_string = x[0] + if len(fragment_string.strip()) == 0: + return Fragment('') + + if in_quotes: + # escape unescaped quotes + fragment_string = fragment_string.replace('"', '\\"') + + #fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment)))) + result = pp.Group(pp.MatchFirst([ + pp.OneOrMore(prompt_part | quoted_fragment), + pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd() + ])).set_name('rr').set_debug(False).parse_string(fragment_string) + #result = (pp.OneOrMore(attention | unquoted_word) + pp.StringEnd()).parse_string(x[0]) + #print("parsed to", result) + return result + + quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"') + quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment') + + escaped_quote = pp.Literal('\\"')#.set_parse_action(lambda x: '"') + escaped_lparen = pp.Literal('\\(')#.set_parse_action(lambda x: '(') + escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')') + escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"') + + def not_ends_with_swap(x): + #print("trying to match:", x) + return not x[0].endswith('.swap') + + unquoted_fragment = pp.Combine(pp.OneOrMore( + escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash | + pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()'))) + unquoted_fragment.set_parse_action(make_text_fragment).set_name('unquoted_fragment').set_debug(False) + #print(unquoted_fragment.parse_string("cat.swap(dog)")) + + parenthesized_fragment << pp.Or([ + (lparen + quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False) + rparen).set_name('-quoted_paren_internal').set_debug(False), + (lparen + rparen).set_parse_action(lambda x: make_text_fragment('')).set_name('-()').set_debug(False), + (lparen + pp.Combine(pp.OneOrMore( + escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash | + pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') | + pp.Word(string.whitespace) + )).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False) + 
rparen)]).set_name('-unquoted_paren_internal').set_debug(False) + parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False) + + debug_attention = False + # attention control of the form (phrase)+ / (phrase)+ / (phrase) + # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight + attention_with_parens = pp.Forward() + attention_without_parens = pp.Forward() + + attention_with_parens_foot = (number | pp.Word('+') | pp.Word('-'))\ + .set_name("attention_foot")\ + .set_debug(False) + attention_with_parens <<= pp.Group( + lparen + + pp.ZeroOrMore(quoted_fragment | attention_with_parens | parenthesized_fragment | cross_attention_substitute | attention_without_parens | + (pp.Empty() + build_escaped_word_parser_charbychar('()')).set_name('undecorated_word').set_debug(debug_attention)#.set_parse_action(lambda t: t[0]) + ) + + rparen + attention_with_parens_foot) + attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention) + + attention_without_parens_foot = pp.Or(pp.Word('+') | pp.Word('-')).set_name('attention_without_parens_foots') + attention_without_parens <<= pp.Group( + (quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot) | + pp.Combine(build_escaped_word_parser_charbychar('()+-')).set_name('attention_word_without_parens').set_debug(debug_attention)#.set_parse_action(lambda x: print('escapéd', x)) + + attention_without_parens_foot)#.leave_whitespace() + attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention) + + + attention << pp.MatchFirst([attention_with_parens, + attention_without_parens + ]) + attention.set_name('attention') + + def make_attention(x): + #print("entered make_attention with", x) + children = x[0][:-1] + weight_raw = x[0][-1] + weight = 1.0 + if type(weight_raw) is float or type(weight_raw) is int: + weight = weight_raw + elif type(weight_raw) is str: + base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base + weight = pow(base, len(weight_raw)) + + #print("making Attention from", children, "with weight", weight) + + return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in children]) + + attention_with_parens.set_parse_action(make_attention) + attention_without_parens.set_parse_action(make_attention) + + #print("parsing test:", attention_with_parens.parse_string("mountain (man)1.1")) + + # cross-attention control + empty_string = ((lparen + rparen) | + pp.Literal('""').suppress() | + (lparen + pp.Literal('""').suppress() + rparen) + ).set_parse_action(lambda x: Fragment("")) + empty_string.set_name('empty_string') + + # cross attention control + debug_cross_attention_control = False + original_fragment = pp.Or([empty_string.set_debug(debug_cross_attention_control), + quoted_fragment.set_debug(debug_cross_attention_control), + parenthesized_fragment.set_debug(debug_cross_attention_control), + pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_text_fragment) + pp.FollowedBy(".swap") + ]) + # support keyword=number arguments + cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")]) + cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number) + edited_fragment = pp.MatchFirst([ + lparen + + (quoted_fragment | + 
pp.Group(pp.OneOrMore(pp.Word(pp.printables, exclude_chars=string.whitespace + ',').set_parse_action(make_text_fragment))) + ) + + pp.Dict(pp.OneOrMore(comma + cross_attention_option)) + + rparen, + parenthesized_fragment + ]) + cross_attention_substitute << original_fragment + pp.Literal(".swap").suppress() + edited_fragment + + original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control) + edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control) + cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control) + + def make_cross_attention_substitute(x): + #print("making cacs for", x[0], "->", x[1], "with options", x.as_dict()) + #if len(x>2): + cacs = CrossAttentionControlSubstitute(x[0], x[1], options=x.as_dict()) + #print("made", cacs) + return cacs + cross_attention_substitute.set_parse_action(make_cross_attention_substitute) + + + # simple fragments of text + # use Or to match the longest + prompt_part << pp.MatchFirst([ + cross_attention_substitute, + attention, + unquoted_fragment, + lparen + unquoted_fragment + rparen # matches case where user has +(term) and just deletes the + + ]) + prompt_part.set_debug(False) + prompt_part.set_name("prompt_part") + + empty = ( + (lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) | + (quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty') + + # root prompt definition + prompt = (pp.OneOrMore(pp.Or([prompt_part, quoted_fragment, empty])) + pp.StringEnd()) \ + .set_parse_action(lambda x: Prompt(x)) + + #print("parsing test:", prompt.parse_string("spaced eyes--")) + #print("parsing test:", prompt.parse_string("eyes--")) + + # weighted blend of prompts + # ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or + # int weights. 
+ # can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c) + + def make_prompt_from_quoted_string(x): + #print(' got quoted prompt', x) + + x_unquoted = x[0][1:-1] + if len(x_unquoted.strip()) == 0: + # print(' b : just an empty string') + return Prompt([Fragment('')]) + # print(' b parsing ', c_unquoted) + x_parsed = prompt.parse_string(x_unquoted) + #print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed) + return x_parsed[0] + + quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string) + quoted_prompt.set_name('quoted_prompt') + + debug_blend=False + blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend) + blend_weights = (pp.delimited_list(number) + pp.Optional(pp.Char(",").suppress() + "no_normalize")).set_name('blend_weights').set_debug(debug_blend) + blend = pp.Group(lparen + pp.Group(blend_terms) + rparen + + pp.Literal(".blend").suppress() + + lparen + pp.Group(blend_weights) + rparen).set_name('blend') + blend.set_debug(debug_blend) + + def make_blend(x): + prompts = x[0][0] + weights = x[0][1] + normalize = True + if weights[-1] == 'no_normalize': + normalize = False + weights = weights[:-1] + return Blend(prompts=prompts, weights=weights, normalize_weights=normalize) + + blend.set_parse_action(make_blend) + + conjunction_terms = blend_terms.copy().set_name('conjunction_terms') + conjunction_weights = blend_weights.copy().set_name('conjunction_weights') + conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen + + pp.Literal(".and").suppress() + + lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction') + def make_conjunction(x): + parts_raw = x[0][0] + weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw) + parts = [part for part in parts_raw] + return Conjunction(parts, weights) + conjunction_with_parens_and_quotes.set_parse_action(make_conjunction) + + implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction') + implicit_conjunction.set_parse_action(lambda x: Conjunction(x)) + + conjunction = conjunction_with_parens_and_quotes | implicit_conjunction + conjunction.set_debug(False) + + # top-level is a conjunction of one or more blends or prompts + return conjunction, prompt + + + +def split_weighted_subprompts(text, skip_normalize=False)->list: + """ + Legacy blend parsing. + + grabs all text up to the first occurrence of ':' + uses the grabbed text as a sub-prompt, and takes the value following ':' as weight + if ':' has no value defined, defaults to 1.0 + repeats until no text remaining + """ + prompt_parser = re.compile(""" + (?P # capture group for 'prompt' + (?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:' + ) # end 'prompt' + (?: # non-capture group + :+ # match one or more ':' characters + (?P # capture group for 'weight' + -?\d+(?:\.\d+)? # match positive or negative integer or decimal number + )? # end weight capture group, make optional + \s* # strip spaces after weight + | # OR + $ # else, if no ':' then match end of line + ) # end non-capture group + """, re.VERBOSE) + parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float( + match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)] + if skip_normalize: + return parsed_prompts + weight_sum = sum(map(lambda x: x[1], parsed_prompts)) + if weight_sum == 0: + print( + "Warning: Subprompt weights add up to zero. 
Discarding and using even weights instead.") + equal_weight = 1 / max(len(parsed_prompts), 1) + return [(x[0], equal_weight) for x in parsed_prompts] + return [(x[0], x[1] / weight_sum) for x in parsed_prompts] + + +# shows how the prompt is tokenized +# usually tokens have '</w>' to indicate end-of-word, +# but for readability it has been replaced with ' ' +def log_tokenization(text, model, log=False, weight=1): + if not log: + return + tokens = model.cond_stage_model.tokenizer._tokenize(text) + tokenized = "" + discarded = "" + usedTokens = 0 + totalTokens = len(tokens) + for i in range(0, totalTokens): + token = tokens[i].replace('</w>', ' ') + # alternate color + s = (usedTokens % 6) + 1 + if i < model.cond_stage_model.max_length: + tokenized = tokenized + f"\x1b[0;3{s};40m{token}" + usedTokens += 1 + else: # over max token length + discarded = discarded + f"\x1b[0;3{s};40m{token}" + print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m") + if discarded != "": + print( + f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m" + ) diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py new file mode 100644 index 0000000000..1e5b073a3d --- /dev/null +++ b/ldm/models/diffusion/cross_attention_control.py @@ -0,0 +1,238 @@ +from enum import Enum + +import torch + +# adapted from bloc97's CrossAttentionControl colab +# https://github.com/bloc97/CrossAttentionControl + +class CrossAttentionControl: + + class Arguments: + def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict): + """ + :param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768] + :param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required) + :param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes. + """ + # todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector + self.edited_conditioning = edited_conditioning + self.edit_opcodes = edit_opcodes + + if edited_conditioning is not None: + assert len(edit_opcodes) == len(edit_options), \ + "there must be 1 edit_options dict for each edit_opcodes tuple" + non_none_edit_options = [x for x in edit_options if x is not None] + assert len(non_none_edit_options)>0, "missing edit_options" + if len(non_none_edit_options)>1: + print('warning: cross-attention control options are not working properly for >1 edit') + self.edit_options = non_none_edit_options[0] + + class Context: + def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int): + """ + :param arguments: Arguments for the cross-attention control process + :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run) + """ + self.arguments = arguments + self.step_count = step_count + + @classmethod + def remove_cross_attention_control(cls, model): + cls.remove_attention_function(model) + + @classmethod + def setup_cross_attention_control(cls, model, + cross_attention_control_args: Arguments + ): + """ + Inject attention parameters and functions into the passed in model to enable cross attention editing. + + :param model: The unet model to inject into.
+ :param cross_attention_control_args: Arugments passeed to the CrossAttentionControl implementations + :return: None + """ + + # adapted from init_attention_edit + device = cross_attention_control_args.edited_conditioning.device + + # urgh. should this be hardcoded? + max_length = 77 + # mask=1 means use base prompt attention, mask=0 means use edited prompt attention + mask = torch.zeros(max_length) + indices_target = torch.arange(max_length, dtype=torch.long) + indices = torch.zeros(max_length, dtype=torch.long) + for name, a0, a1, b0, b1 in cross_attention_control_args.edit_opcodes: + if b0 < max_length: + if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): + # these tokens have not been edited + indices[b0:b1] = indices_target[a0:a1] + mask[b0:b1] = 1 + + for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF): + m.last_attn_slice_mask = None + m.last_attn_slice_indices = None + + for m in cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS): + m.last_attn_slice_mask = mask.to(device) + m.last_attn_slice_indices = indices.to(device) + + cls.inject_attention_function(model) + + + class CrossAttentionType(Enum): + SELF = 1 + TOKENS = 2 + + @classmethod + def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', percent_through:float=None)\ + -> list['CrossAttentionControl.CrossAttentionType']: + """ + Should cross-attention control be applied on the given step? + :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0. + :return: A list of attention types that cross-attention control should be performed for on the given step. May be []. + """ + if percent_through is None: + return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS] + + opts = context.arguments.edit_options + to_control = [] + if opts['s_start'] <= percent_through and percent_through < opts['s_end']: + to_control.append(cls.CrossAttentionType.SELF) + if opts['t_start'] <= percent_through and percent_through < opts['t_end']: + to_control.append(cls.CrossAttentionType.TOKENS) + return to_control + + + @classmethod + def get_attention_modules(cls, model, which: CrossAttentionType): + which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2" + return [module for name, module in model.named_modules() if + type(module).__name__ == "CrossAttention" and which_attn in name] + + @classmethod + def clear_requests(cls, model): + self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF) + tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS) + for m in self_attention_modules+tokens_attention_modules: + m.save_last_attn_slice = False + m.use_last_attn_slice = False + + @classmethod + def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType): + modules = cls.get_attention_modules(model, cross_attention_type) + for m in modules: + # clear out the saved slice in case the outermost dim changes + m.last_attn_slice = None + m.save_last_attn_slice = True + + @classmethod + def request_apply_saved_attention_maps(cls, model, cross_attention_type: CrossAttentionType): + modules = cls.get_attention_modules(model, cross_attention_type) + for m in modules: + m.use_last_attn_slice = True + + + + @classmethod + def inject_attention_function(cls, unet): + # ORIGINAL SOURCE CODE: 
https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 + + def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size): + + #print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim) + + attn_slice = suggested_attention_slice + if dim is not None: + start = offset + end = start+slice_size + #print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}") + #else: + # print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}") + + + if self.use_last_attn_slice: + this_attn_slice = attn_slice + if self.last_attn_slice_mask is not None: + # indices and mask operate on dim=2, no need to slice + base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices) + base_attn_slice_mask = self.last_attn_slice_mask + if dim is None: + base_attn_slice = base_attn_slice_full + #print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) + elif dim == 0: + base_attn_slice = base_attn_slice_full[start:end] + #print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) + elif dim == 1: + base_attn_slice = base_attn_slice_full[:, start:end] + #print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape) + + attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \ + base_attn_slice * base_attn_slice_mask + else: + if dim is None: + attn_slice = self.last_attn_slice + #print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) + elif dim == 0: + attn_slice = self.last_attn_slice[start:end] + #print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) + elif dim == 1: + attn_slice = self.last_attn_slice[:, start:end] + #print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) + + if self.save_last_attn_slice: + if dim is None: + self.last_attn_slice = attn_slice + elif dim == 0: + # dynamically grow last_attn_slice if needed + if self.last_attn_slice is None: + self.last_attn_slice = attn_slice + #print("no last_attn_slice: shape now", self.last_attn_slice.shape) + elif self.last_attn_slice.shape[0] == start: + self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0) + assert(self.last_attn_slice.shape[0] == end) + #print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape) + else: + # no need to grow + self.last_attn_slice[start:end] = attn_slice + #print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape) + + elif dim == 1: + # dynamically grow last_attn_slice if needed + if self.last_attn_slice is None: + self.last_attn_slice = attn_slice + elif self.last_attn_slice.shape[1] == start: + self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1) + assert(self.last_attn_slice.shape[1] == end) + else: + # no need to grow + self.last_attn_slice[:, start:end] = attn_slice + + if self.use_last_attn_weights and self.last_attn_slice_weights is not None: + if dim is None: + weights = 
self.last_attn_slice_weights + elif dim == 0: + weights = self.last_attn_slice_weights[start:end] + elif dim == 1: + weights = self.last_attn_slice_weights[:, start:end] + attn_slice = attn_slice * weights + + return attn_slice + + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention": + module.last_attn_slice = None + module.last_attn_slice_indices = None + module.last_attn_slice_mask = None + module.use_last_attn_weights = False + module.use_last_attn_slice = False + module.save_last_attn_slice = False + module.set_attention_slice_wrangler(attention_slice_wrangler) + + @classmethod + def remove_attention_function(cls, unet): + for name, module in unet.named_modules(): + module_name = type(module).__name__ + if module_name == "CrossAttention": + module.set_attention_slice_wrangler(None) + diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index f5dada8627..b11e8578e7 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -1,10 +1,7 @@ """SAMPLING ONLY.""" import torch -import numpy as np -from tqdm import tqdm -from functools import partial -from ldm.invoke.devices import choose_torch_device +from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from ldm.models.diffusion.sampler import Sampler from ldm.modules.diffusionmodules.util import noise_like @@ -12,6 +9,21 @@ class DDIMSampler(Sampler): def __init__(self, model, schedule='linear', device=None, **kwargs): super().__init__(model,schedule,model.num_timesteps,device) + self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model, + model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond)) + + def prepare_to_sample(self, t_enc, **kwargs): + super().prepare_to_sample(t_enc, **kwargs) + + extra_conditioning_info = kwargs.get('extra_conditioning_info', None) + all_timesteps_count = kwargs.get('all_timesteps_count', t_enc) + + if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: + self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count) + else: + self.invokeai_diffuser.remove_cross_attention_control() + + # This is the central routine @torch.no_grad() def p_sample( @@ -29,6 +41,7 @@ class DDIMSampler(Sampler): corrector_kwargs=None, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + step_count:int=1000, # total number of steps **kwargs, ): b, *_, device = *x.shape, x.device @@ -37,15 +50,14 @@ class DDIMSampler(Sampler): unconditional_conditioning is None or unconditional_guidance_scale == 1.0 ): + # damian0815 would like to know when/if this code path is used e_t = self.model.apply_model(x, t, c) else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * ( - e_t - e_t_uncond - ) + step_index = step_count-(index+1) + e_t = self.invokeai_diffuser.do_diffusion_step(x, t, + unconditional_conditioning, c, + unconditional_guidance_scale, + step_index=step_index) if score_corrector is not None: assert self.model.parameterization == 'eps' diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index 4b62b5e393..57027b224c 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -820,21 +820,21 @@ class LatentDiffusion(DDPM): ) return 
self.scale_factor * z - def get_learned_conditioning(self, c): + def get_learned_conditioning(self, c, **kwargs): if self.cond_stage_forward is None: if hasattr(self.cond_stage_model, 'encode') and callable( self.cond_stage_model.encode ): c = self.cond_stage_model.encode( - c, embedding_manager=self.embedding_manager + c, embedding_manager=self.embedding_manager, **kwargs ) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() else: - c = self.cond_stage_model(c) + c = self.cond_stage_model(c, **kwargs) else: assert hasattr(self.cond_stage_model, self.cond_stage_forward) - c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c, **kwargs) return c def meshgrid(self, h, w): diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index ac0615b30c..2f5bf53850 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -1,16 +1,12 @@ """wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers""" + import k_diffusion as K import torch -import torch.nn as nn -from ldm.invoke.devices import choose_torch_device -from ldm.models.diffusion.sampler import Sampler -from ldm.util import rand_perlin_2d -from ldm.modules.diffusionmodules.util import ( - make_ddim_sampling_parameters, - make_ddim_timesteps, - noise_like, - extract_into_tensor, -) +from torch import nn + +from .sampler import Sampler +from .shared_invokeai_diffusion import InvokeAIDiffuserComponent + def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7): if threshold <= 0.0: @@ -33,12 +29,24 @@ class CFGDenoiser(nn.Module): self.threshold = threshold self.warmup_max = warmup self.warmup = max(warmup / 10, 1) + self.invokeai_diffuser = InvokeAIDiffuserComponent(model, + model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond)) + + def prepare_to_sample(self, t_enc, **kwargs): + + extra_conditioning_info = kwargs.get('extra_conditioning_info', None) + + if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: + self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc) + else: + self.invokeai_diffuser.remove_cross_attention_control() + def forward(self, x, sigma, uncond, cond, cond_scale): - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + + next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale) + + # apply threshold if self.warmup < self.warmup_max: thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) self.warmup += 1 @@ -46,7 +54,8 @@ class CFGDenoiser(nn.Module): thresh = self.threshold if thresh > self.threshold: thresh = self.threshold - return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh) + return cfg_apply_threshold(next_x, thresh) + class KSampler(Sampler): @@ -61,16 +70,6 @@ class KSampler(Sampler): self.ds = None self.s_in = None - def forward(self, x, sigma, uncond, cond, cond_scale): - x_in = torch.cat([x] * 2) - sigma_in = torch.cat([sigma] * 2) - cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model( - x_in, sigma_in, cond=cond_in - ).chunk(2) - return uncond + (cond - uncond) * cond_scale - - def make_schedule( self, ddim_num_steps, @@ -118,6 +117,7 @@ class KSampler(Sampler): use_original_steps=False, init_latent = 
None, mask = None, + **kwargs ): samples,_ = self.sample( batch_size = 1, @@ -129,7 +129,8 @@ class KSampler(Sampler): unconditional_conditioning = unconditional_conditioning, img_callback = img_callback, x0 = init_latent, - mask = mask + mask = mask, + **kwargs ) return samples @@ -163,6 +164,7 @@ class KSampler(Sampler): log_every_t=100, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + extra_conditioning_info=None, threshold = 0, perlin = 0, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... @@ -181,7 +183,6 @@ class KSampler(Sampler): ) # sigmas are set up in make_schedule - we take the last steps items - total_steps = len(self.sigmas) sigmas = self.sigmas[-S-1:] # x_T is variation noise. When an init image is provided (in x0) we need to add @@ -195,19 +196,21 @@ class KSampler(Sampler): x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10)) + model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info) extra_args = { 'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale, } print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)') - return ( + sampling_result = ( K.sampling.__dict__[f'sample_{self.schedule}']( model_wrap_cfg, x, sigmas, extra_args=extra_args, callback=route_callback ), None, ) + return sampling_result # this code will support inpainting if and when ksampler API modified or # a workaround is found. @@ -220,6 +223,7 @@ class KSampler(Sampler): index, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + extra_conditioning_info=None, **kwargs, ): if self.model_wrap is None: @@ -245,6 +249,7 @@ class KSampler(Sampler): # so the actual formula for indexing into sigmas: # sigma_index = (steps-index) s_index = t_enc - index - 1 + self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info) img = K.sampling.__dict__[f'_{self.schedule}']( self.model_wrap, img, @@ -269,7 +274,7 @@ class KSampler(Sampler): else: return x - def prepare_to_sample(self,t_enc): + def prepare_to_sample(self,t_enc,**kwargs): self.t_enc = t_enc self.model_wrap = None self.ds = None diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 9e722eb932..6bd519b63b 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -5,6 +5,7 @@ import numpy as np from tqdm import tqdm from functools import partial from ldm.invoke.devices import choose_torch_device +from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from ldm.models.diffusion.sampler import Sampler from ldm.modules.diffusionmodules.util import noise_like @@ -13,6 +14,21 @@ class PLMSSampler(Sampler): def __init__(self, model, schedule='linear', device=None, **kwargs): super().__init__(model,schedule,model.num_timesteps, device) + self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model, + model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond)) + + def prepare_to_sample(self, t_enc, **kwargs): + super().prepare_to_sample(t_enc, **kwargs) + + extra_conditioning_info = kwargs.get('extra_conditioning_info', None) + all_timesteps_count = kwargs.get('all_timesteps_count', t_enc) + + if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: + 
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count) + else: + self.invokeai_diffuser.remove_cross_attention_control() + + # this is the essential routine @torch.no_grad() def p_sample( @@ -32,6 +48,7 @@ class PLMSSampler(Sampler): unconditional_conditioning=None, old_eps=[], t_next=None, + step_count:int=1000, # total number of steps **kwargs, ): b, *_, device = *x.shape, x.device @@ -41,17 +58,15 @@ class PLMSSampler(Sampler): unconditional_conditioning is None or unconditional_guidance_scale == 1.0 ): + # damian0815 would like to know when/if this code path is used e_t = self.model.apply_model(x, t, c) else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model( - x_in, t_in, c_in - ).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * ( - e_t - e_t_uncond - ) + # step_index counts in the opposite direction to index + step_index = step_count-(index+1) + e_t = self.invokeai_diffuser.do_diffusion_step(x, t, + unconditional_conditioning, c, + unconditional_guidance_scale, + step_index=step_index) if score_corrector is not None: assert self.model.parameterization == 'eps' diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index ff705513f8..853702ef68 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -4,6 +4,8 @@ ldm.models.diffusion.sampler Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc ''' +from math import ceil + import torch import numpy as np from tqdm import tqdm @@ -190,6 +192,7 @@ class Sampler(object): unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, steps=S, + **kwargs ) return samples, intermediates @@ -214,6 +217,7 @@ class Sampler(object): unconditional_guidance_scale=1.0, unconditional_conditioning=None, steps=None, + **kwargs ): b = shape[0] time_range = ( @@ -231,7 +235,7 @@ class Sampler(object): dynamic_ncols=True, ) old_eps = [] - self.prepare_to_sample(t_enc=total_steps) + self.prepare_to_sample(t_enc=total_steps,all_timesteps_count=steps,**kwargs) img = self.get_initial_image(x_T,shape,total_steps) # probably don't need this at all @@ -274,6 +278,7 @@ class Sampler(object): unconditional_conditioning=unconditional_conditioning, old_eps=old_eps, t_next=ts_next, + step_count=steps ) img, pred_x0, e_t = outs @@ -305,6 +310,8 @@ class Sampler(object): use_original_steps=False, init_latent = None, mask = None, + all_timesteps_count = None, + **kwargs ): timesteps = ( @@ -321,7 +328,7 @@ class Sampler(object): iterator = tqdm(time_range, desc='Decoding image', total=total_steps) x_dec = x_latent x0 = init_latent - self.prepare_to_sample(t_enc=total_steps) + self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs) for i, step in enumerate(iterator): index = total_steps - i - 1 @@ -353,6 +360,7 @@ class Sampler(object): unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning, t_next = ts_next, + step_count=len(self.ddim_timesteps) ) x_dec, pred_x0, e_t = outs @@ -411,3 +419,4 @@ class Sampler(object): return self.model.inner_model.q_sample(x0,ts) ''' return self.model.q_sample(x0,ts) + diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py new file mode 100644 index 0000000000..b8a7a04d0e --- /dev/null +++ 
b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -0,0 +1,176 @@ +from math import ceil +from typing import Callable, Optional + +import torch + +from ldm.models.diffusion.cross_attention_control import CrossAttentionControl + + +class InvokeAIDiffuserComponent: + ''' + The aim of this component is to provide a single place for code that can be applied identically to + all InvokeAI diffusion procedures. + + At the moment it includes the following features: + * Cross Attention Control ("prompt2prompt") + ''' + + + class ExtraConditioningInfo: + def __init__(self, cross_attention_control_args: Optional[CrossAttentionControl.Arguments]): + self.cross_attention_control_args = cross_attention_control_args + + @property + def wants_cross_attention_control(self): + return self.cross_attention_control_args is not None + + def __init__(self, model, model_forward_callback: + Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor] + ): + """ + :param model: the unet model to pass through to cross attention control + :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning) + """ + self.model = model + self.model_forward_callback = model_forward_callback + + + def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int): + self.conditioning = conditioning + self.cross_attention_control_context = CrossAttentionControl.Context( + arguments=self.conditioning.cross_attention_control_args, + step_count=step_count + ) + CrossAttentionControl.setup_cross_attention_control(self.model, + cross_attention_control_args=self.conditioning.cross_attention_control_args + ) + #todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct + #todo: apply edit_options using step_count + + + def remove_cross_attention_control(self): + self.conditioning = None + self.cross_attention_control_context = None + CrossAttentionControl.remove_cross_attention_control(self.model) + + def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor, + unconditioning: torch.Tensor, conditioning: torch.Tensor, + unconditional_guidance_scale: float, + step_index: int=None + ): + """ + :param x: Current latents + :param sigma: aka t, passed to the internal model to control how much denoising will occur + :param unconditioning: [B x 77 x 768] embeddings for unconditioned output + :param conditioning: [B x 77 x 768] embeddings for conditioned output + :param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has + :param step_index: Counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotonically increase. + :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
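
        Illustrative call (editor's sketch, mirroring the ddim.py / plms.py call sites in this same diff; not itself part of the patch):

            e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
                                                           unconditional_conditioning, c,
                                                           unconditional_guidance_scale,
                                                           step_index=step_count - (index + 1))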
+ """ + + CrossAttentionControl.clear_requests(self.model) + cross_attention_control_types_to_do = [] + + if self.cross_attention_control_context is not None: + if step_index is not None: + # percent_through will never reach 1.0 (but this is intended) + percent_through = float(step_index) / float(self.cross_attention_control_context.step_count) + else: + # find the current sigma in the sigma sequence + # todo: this doesn't work with k_dpm_2 because the sigma used jumps around in the sequence + sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1] + # flip because sigmas[0] is for the fully denoised image + # percent_through must be <1 + percent_through = 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0]) + #print('estimated percent_through', percent_through, 'from sigma', sigma.item()) + cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through) + + if len(cross_attention_control_types_to_do)==0: + #print('not doing cross attention control') + # faster batched path + x_twice = torch.cat([x]*2) + sigma_twice = torch.cat([sigma]*2) + both_conditionings = torch.cat([unconditioning, conditioning]) + unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2) + else: + #print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do) + # slower non-batched path (20% slower on mac MPS) + # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of + # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x. + # This messes up their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8) + # (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16, + # representing batched uncond + cond, but then when it comes to applying the saved attention, the + # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.) + # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
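            # (editor's note: illustrative summary, not part of the original patch)
            # Example of the gating above: with step_count=50 and step_index=10, percent_through=0.2, so a swap
            # using s_start=0, s_end=0.3 still applies SELF (self-attention) control on this step, while one
            # using s_end=0.1 no longer does.
            # The non-batched path below typically runs three separate forward passes per step: unconditioned,
            # original-conditioned (to save attention maps), and edited-conditioned (re-applying the saved maps).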
+ unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) + + # process x using the original prompt, saving the attention maps + for type in cross_attention_control_types_to_do: + CrossAttentionControl.request_save_attention_maps(self.model, type) + _ = self.model_forward_callback(x, sigma, conditioning) + CrossAttentionControl.clear_requests(self.model) + + # process x again, using the saved attention maps to control where self.edited_conditioning will be applied + for type in cross_attention_control_types_to_do: + CrossAttentionControl.request_apply_saved_attention_maps(self.model, type) + edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning + conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning) + CrossAttentionControl.clear_requests(self.model) + + + # to scale how much effect conditioning has, calculate the changes it does and then scale that + scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale + combined_next_x = unconditioned_next_x + scaled_delta + + return combined_next_x + + # todo: make this work + @classmethod + def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale): + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) # aka sigmas + + deltas = None + uncond_latents = None + weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)] + + # below is fugly omg + num_actual_conditionings = len(c_or_weighted_c_list) + conditionings = [uc] + [c for c,weight in weighted_cond_list] + weights = [1] + [weight for c,weight in weighted_cond_list] + chunk_count = ceil(len(conditionings)/2) + deltas = None + for chunk_index in range(chunk_count): + offset = chunk_index*2 + chunk_size = min(2, len(conditionings)-offset) + + if chunk_size == 1: + c_in = conditionings[offset] + latents_a = forward_func(x_in[:-1], t_in[:-1], c_in) + latents_b = None + else: + c_in = torch.cat(conditionings[offset:offset+2]) + latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2) + + # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining + if chunk_index == 0: + uncond_latents = latents_a + deltas = latents_b - uncond_latents + else: + deltas = torch.cat((deltas, latents_a - uncond_latents)) + if latents_b is not None: + deltas = torch.cat((deltas, latents_b - uncond_latents)) + + # merge the weighted deltas together into a single merged delta + per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device) + normalize = False + if normalize: + per_delta_weights /= torch.sum(per_delta_weights) + reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1)) + deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True) + + # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale) + # assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale)))) + + return uncond_latents + deltas_merged * global_guidance_scale + diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index ef9c2d3e65..8d160f004b 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -1,5 +1,7 @@ from inspect import isfunction import math +from typing import Callable + import torch import torch.nn.functional as F from torch import nn, einsum @@ -150,6 +152,7 @@ class SpatialSelfAttention(nn.Module): return x+h_ + class CrossAttention(nn.Module): def __init__(self, 
query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): super().__init__() @@ -170,46 +173,73 @@ class CrossAttention(nn.Module): self.mem_total_gb = psutil.virtual_memory().total // (1 << 30) - def einsum_op_compvis(self, q, k, v): - s = einsum('b i d, b j d -> b i j', q, k) - s = s.softmax(dim=-1, dtype=s.dtype) - return einsum('b i j, b j d -> b i d', s, v) + self.attention_slice_wrangler = None - def einsum_op_slice_0(self, q, k, v, slice_size): + def set_attention_slice_wrangler(self, wrangler:Callable[[nn.Module, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]): + ''' + Set custom attention calculator to be called when attention is calculated + :param wrangler: Callback, with args (self, attention_scores, suggested_attention_slice, dim, offset, slice_size), + which returns either the suggested_attention_slice or an adjusted equivalent. + self is the current CrossAttention module for which the callback is being invoked. + attention_scores are the scores for attention + suggested_attention_slice is a softmax(dim=-1) over attention_scores + dim is None if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing. + If dim is >= 0, offset and slice_size specify the slice start and length. + + Pass None to use the default attention calculation. + :return: + ''' + self.attention_slice_wrangler = wrangler + + def einsum_lowest_level(self, q, k, v, dim, offset, slice_size): + # calculate attention scores + attention_scores = einsum('b i d, b j d -> b i j', q, k) + # calculate attention slice by taking the best scores for each latent pixel + default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype) + if self.attention_slice_wrangler is not None: + attention_slice = self.attention_slice_wrangler(self, attention_scores, default_attention_slice, dim, offset, slice_size) + else: + attention_slice = default_attention_slice + + return einsum('b i j, b j d -> b i d', attention_slice, v) + + def einsum_op_slice_dim0(self, q, k, v, slice_size): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[0], slice_size): end = i + slice_size - r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end]) + r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size) return r - def einsum_op_slice_1(self, q, k, v, slice_size): + def einsum_op_slice_dim1(self, q, k, v, slice_size): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[1], slice_size): end = i + slice_size - r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v) + r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size) return r def einsum_op_mps_v1(self, q, k, v): if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 - return self.einsum_op_compvis(q, k, v) + return self.einsum_lowest_level(q, k, v, None, None, None) else: slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) - return self.einsum_op_slice_1(q, k, v, slice_size) + return self.einsum_op_slice_dim1(q, k, v, slice_size) def einsum_op_mps_v2(self, q, k, v): if self.mem_total_gb > 8 and q.shape[1] <= 4096: - return self.einsum_op_compvis(q, k, v) + return self.einsum_lowest_level(q, k, v, None, None, None) else: - return self.einsum_op_slice_0(q, k, v, 1) + return self.einsum_op_slice_dim0(q, k, v, 1) def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb): size_mb = q.shape[0] * q.shape[1] * k.shape[1] * 
q.element_size() // (1 << 20) if size_mb <= max_tensor_mb: - return self.einsum_op_compvis(q, k, v) + return self.einsum_lowest_level(q, k, v, None, None, None) div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() if div <= q.shape[0]: - return self.einsum_op_slice_0(q, k, v, q.shape[0] // div) - return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) + print("warning: untested call to einsum_op_slice_dim0") + return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div) + print("warning: untested call to einsum_op_slice_dim1") + return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1)) def einsum_op_cuda(self, q, k, v): stats = torch.cuda.memory_stats(q.device) @@ -221,7 +251,7 @@ class CrossAttention(nn.Module): # Divide factor of safety as there's copying and fragmentation return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) - def einsum_op(self, q, k, v): + def get_attention_mem_efficient(self, q, k, v): if q.device.type == 'cuda': return self.einsum_op_cuda(q, k, v) @@ -244,8 +274,13 @@ class CrossAttention(nn.Module): del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - r = self.einsum_op(q, k, v) - return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) + + r = self.get_attention_mem_efficient(q, k, v) + + hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h) + return self.to_out(hidden_states) + + class BasicTransformerBlock(nn.Module): diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 426fccced3..8917a27a40 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -1,3 +1,5 @@ +import math + import torch import torch.nn as nn from functools import partial @@ -437,6 +439,7 @@ class FrozenCLIPEmbedder(AbstractEncoder): param.requires_grad = False def forward(self, text, **kwargs): + batch_encoding = self.tokenizer( text, truncation=True, @@ -454,6 +457,222 @@ class FrozenCLIPEmbedder(AbstractEncoder): def encode(self, text, **kwargs): return self(text, **kwargs) +class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): + + fragment_weights_key = "fragment_weights" + return_tokens_key = "return_tokens" + + def forward(self, text: list, **kwargs): + ''' + + :param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different + weights shall be applied. + :param kwargs: If the keyword arg "fragment_weights" is passed, it shall contain a batch of lists of weights + for the prompt fragments. In this case text must contain batches of lists of prompt fragments. + :return: A tensor of shape (B, 77, 768) containing weighted embeddings + ''' + if self.fragment_weights_key not in kwargs: + # fallback to base class implementation + return super().forward(text, **kwargs) + + fragment_weights = kwargs[self.fragment_weights_key] + # self.transformer doesn't like receiving "fragment_weights" as an argument + kwargs.pop(self.fragment_weights_key) + + should_return_tokens = False + if self.return_tokens_key in kwargs: + should_return_tokens = kwargs.get(self.return_tokens_key, False) + # self.transformer doesn't like having extra kwargs + kwargs.pop(self.return_tokens_key) + + batch_z = None + batch_tokens = None + for fragments, weights in zip(text, fragment_weights): + + # First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively + # applying a multiplier to the CFG scale on a per-token basis). 
+ # For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept + # captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active + # interest, however small, in redness; what the user probably intends when they attach the number 0.01 to + # "red" is to tell SD that it should almost completely *ignore* redness). + # To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt + # string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the + # closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment. + + # handle weights >=1 + tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights) + base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs) + + # this is our starting point + embeddings = base_embedding.unsqueeze(0) + per_embedding_weights = [1.0] + + # now handle weights <1 + # Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped + # with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting + # embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words + # removed. + # eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding + # for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it + # such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain". + for index, fragment_weight in enumerate(weights): + if fragment_weight < 1: + fragments_without_this = fragments[:index] + fragments[index+1:] + weights_without_this = weights[:index] + weights[index+1:] + tokens, per_token_weights = self.get_tokens_and_weights(fragments_without_this, weights_without_this) + embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs) + + embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1) + # weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0 + # if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding + # therefore: + # fragment_weight = 1: we are at base_z => lerp weight 0 + # fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1 + # fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf + # so let's use tan(), because: + # tan is 0.0 at 0, + # 1.0 at PI/4, and + # inf at PI/2 + # -> tan((1-weight)*PI/2) should give us ideal lerp weights + epsilon = 1e-9 + fragment_weight = max(epsilon, fragment_weight) # inf is bad + embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2) + # todo handle negative weight? 
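                    # (editor's note: worked example of the lerp weights, not part of the original patch)
                    # tan((1 - w) * pi/2) for a few fragment weights w:
                    #   w = 1.0  -> 0.0   (the "without this fragment" embedding contributes nothing)
                    #   w = 0.5  -> 1.0   (equal pull: the result is halfway between "with" and "without")
                    #   w = 0.25 -> ~2.41 (the "without" embedding dominates)
                    #   w -> 0   -> inf   (clamped by epsilon above, so the fragment is almost completely ignored)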
+ + per_embedding_weights.append(embedding_lerp_weight) + + lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0) + + #print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}") + + # append to batch + batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1) + batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1) + + # should have shape (B, 77, 768) + #print(f"assembled all tokens into tensor of shape {batch_z.shape}") + + if should_return_tokens: + return batch_z, batch_tokens + else: + return batch_z + + def get_tokens(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]: + tokens = self.tokenizer( + fragments, + truncation=True, + max_length=self.max_length, + return_overflowing_tokens=False, + padding='do_not_pad', + return_tensors=None, # just give me a list of ints + )['input_ids'] + if include_start_and_end_markers: + return tokens + else: + return [x[1:-1] for x in tokens] + + + @classmethod + def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor: + per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device) + if normalize: + per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights) + reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,)) + #reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape) + return torch.sum(embeddings * reshaped_weights, dim=1) + # lerped embeddings has shape (77, 768) + + + def get_tokens_and_weights(self, fragments: list[str], weights: list[float]) -> (torch.Tensor, torch.Tensor): + ''' + + :param fragments: + :param weights: Per-fragment weights (CFG scaling). No need for these to be normalized. They will not be normalized here and that's fine. 
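        Illustrative example (editor's sketch, not part of the patch): fragments=['mountain man', 'beard'] with
        weights=[1.0, 2.0] would typically produce tokens [bos, mountain, man, beard, eos, eos, ...] with
        per-token weights [1, 1.0, 1.0, 2.0, 1, 1, ...] - each fragment's weight is repeated for every one of
        its tokens, and the BOS and EOS/padding positions get weight 1.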
+ :return: + ''' + # empty is meaningful + if len(fragments) == 0 and len(weights) == 0: + fragments = [''] + weights = [1] + item_encodings = self.tokenizer( + fragments, + truncation=True, + max_length=self.max_length, + return_overflowing_tokens=False, + padding='do_not_pad', + return_tensors=None, # just give me a list of ints + )['input_ids'] + all_tokens = [] + per_token_weights = [] + #print("all fragments:", fragments, weights) + for index, fragment in enumerate(item_encodings): + weight = weights[index] + #print("processing fragment", fragment, weight) + fragment_tokens = item_encodings[index] + #print("fragment", fragment, "processed to", fragment_tokens) + # trim bos and eos markers before appending + all_tokens.extend(fragment_tokens[1:-1]) + per_token_weights.extend([weight] * (len(fragment_tokens) - 2)) + + if (len(all_tokens) + 2) > self.max_length: + excess_token_count = (len(all_tokens) + 2) - self.max_length + print(f"prompt is {excess_token_count} token(s) too long and has been truncated") + all_tokens = all_tokens[:self.max_length - 2] + + # pad out to a 77-entry array: [eos_token, , eos_token, ..., eos_token] + # (77 = self.max_length) + pad_length = self.max_length - 1 - len(all_tokens) + all_tokens.insert(0, self.tokenizer.bos_token_id) + all_tokens.extend([self.tokenizer.eos_token_id] * pad_length) + per_token_weights.insert(0, 1) + per_token_weights.extend([1] * pad_length) + + all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device) + per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device) + #print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}") + return all_tokens_tensor, per_token_weights_tensor + + def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor: + ''' + Build a tensor representing the passed-in tokens, each of which has a weight. + :param tokens: A tensor of shape (77) containing token ids (integers) + :param per_token_weights: A tensor of shape (77) containing weights (floats) + :param method: Whether to multiply the whole feature vector for each token or just its distance from an "empty" feature vector + :param kwargs: passed on to self.transformer() + :return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings. + ''' + #print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}") + z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs) + batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape) + + if weight_delta_from_empty: + empty_tokens = self.tokenizer([''] * z.shape[0], + truncation=True, + max_length=self.max_length, + padding='max_length', + return_tensors='pt' + )['input_ids'].to(self.device) + empty_z = self.transformer(input_ids=empty_tokens, **kwargs) + z_delta_from_empty = z - empty_z + weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded) + + weighted_z_delta_from_empty = (weighted_z-empty_z) + #print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() ) + + #print("using empty-delta method, first 5 rows:") + #print(weighted_z[:5]) + + return weighted_z + + else: + original_mean = z.mean() + z *= batch_weights_expanded + after_weighting_mean = z.mean() + # correct the mean. 
not sure if this is right but it's what the automatic1111 fork of SD does + mean_correction_factor = original_mean/after_weighting_mean + z *= mean_correction_factor + return z + class FrozenCLIPTextEmbedder(nn.Module): """ diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py new file mode 100644 index 0000000000..486265d2f5 --- /dev/null +++ b/tests/test_prompt_parser.py @@ -0,0 +1,401 @@ +import unittest + +import pyparsing + +from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute, \ + Fragment + + +def parse_prompt(prompt_string): + pp = PromptParser() + #print(f"parsing '{prompt_string}'") + parse_result = pp.parse_conjunction(prompt_string) + #print(f"-> parsed '{prompt_string}' to {parse_result}") + return parse_result + +def make_basic_conjunction(strings: list[str]): + fragments = [Fragment(x) for x in strings] + return Conjunction([FlattenedPrompt(fragments)]) + +def make_weighted_conjunction(weighted_strings: list[tuple[str,float]]): + fragments = [Fragment(x, w) for x,w in weighted_strings] + return Conjunction([FlattenedPrompt(fragments)]) + + +class PromptParserTestCase(unittest.TestCase): + + def test_empty(self): + self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt('')) + + def test_basic(self): + self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)")) + self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames")) + self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames")) + self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire")) + + def test_attention(self): + self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5")) + self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5")) + self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+")) + self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+")) + self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+")) + self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-")) + self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-")) + self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("\"flames\"-")) + self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5")) + self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++")) + self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--")) + self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++")) + self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1)]), + parse_prompt("(pretty flowers)+")) + self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1), (', the flames are too hot', 1)]), + parse_prompt("(pretty flowers)+, the flames are too hot")) + + def test_no_parens_attention_runon(self): + self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(1.1, 2))]), parse_prompt("fire flames++")) + self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(0.9, 2))]), parse_prompt("fire flames--")) + 
self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("flowers fire++ flames")) + self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("flowers fire-- flames")) + + + def test_explicit_conjunction(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and(1,1)')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()')) + self.assertEqual( + Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("(fire)2.0", "flames-").and()')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]), + FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()')) + + def test_conjunction_weights(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[2.0,1.0]), parse_prompt('("fire", "flames").and(2,1)')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[1.0,2.0]), parse_prompt('("fire", "flames").and(1,2)')) + + with self.assertRaises(PromptParser.ParsingException): + parse_prompt('("fire", "flames").and(2)') + parse_prompt('("fire", "flames").and(2,1,2)') + + def test_complex_conjunction(self): + + #print(parse_prompt("a person with a hat (riding a bicycle.swap(skateboard))++")) + + self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]), + parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle)++\").and(0.5, 0.5)")) + self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), + FlattenedPrompt([("a person with a hat", 1.0), + ("riding a", 1.1*1.1), + CrossAttentionControlSubstitute( + [Fragment("bicycle", pow(1.1,2))], + [Fragment("skateboard", pow(1.1,2))]) + ]) + ], weights=[0.5, 0.5]), + parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)")) + + def test_badly_formed(self): + def make_untouched_prompt(prompt): + return Conjunction([FlattenedPrompt([(prompt, 1.0)])]) + + def assert_if_prompt_string_not_untouched(prompt): + self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt)) + + assert_if_prompt_string_not_untouched('a test prompt') + # todo handle this + #assert_if_prompt_string_not_untouched('a badly formed +test prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('a badly (formed test prompt') + #with self.assertRaises(pyparsing.ParseException): + with self.assertRaises(pyparsing.ParseException): + parse_prompt('a badly (formed +test prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('a badly (formed +test )prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('a badly (formed +test )prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('(((a badly (formed +test )prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('(a (ba)dly (f)ormed +test 
prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('(a (ba)dly (f)ormed +test +prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('("((a badly (formed +test ").blend(1.0)') + + + def test_blend(self): + self.assertEqual(Conjunction( + [Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]), + parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)") + ) + self.assertEqual(Conjunction([Blend( + [FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('hi', 1.0)])], + [0.7, 0.3, 1.0])]), + parse_prompt("(\"fire\", \"fire flames\", \"hi\").blend(0.7, 0.3, 1.0)") + ) + self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), + FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]), + FlattenedPrompt([('hi', 1.0)])], + weights=[0.7, 0.3, 1.0])]), + parse_prompt("(\"fire\", \"fire flames (hot)++\", \"hi\").blend(0.7, 0.3, 1.0)") + ) + # blend a single entry is not a failure + self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]), + parse_prompt("(\"fire\").blend(0.7)") + ) + # blend with empty + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \"\").blend(0.7, 1)") + ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \" \").blend(0.7, 1)") + ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \" \").blend(0.7, 1)") + ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([(',', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \" , \").blend(0.7, 1)") + ) + + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('mountain, man, hairy', 1)]), + FlattenedPrompt([('face, teeth,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0])]), + parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)') + ) + + + def test_nested(self): + self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)]), + parse_prompt('fire (flames (trees)1.5)2.0')) + self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]), + FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])], + weights=[1.0, 1.0])]), + parse_prompt('("fire (flames)++", "mountain (man)2").blend(1,1)')) + + def test_cross_attention_control(self): + + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]), + Fragment('eating a hotdog', 1)])]), parse_prompt("a \"cat\".swap(dog) eating a hotdog")) + + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]), + Fragment('eating a hotdog', 1)])]), parse_prompt("a cat.swap(dog) eating a hotdog")) + + + fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \ + CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])]) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)')) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)')) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)')) + 
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap("trees")')) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap("trees")')) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")')) + + fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \ + CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees and houses', 1)])])]) + self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")')) + self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")')) + self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")')) + + trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \ + CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])]) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap(flames)')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap(flames)')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)')) + + flames_to_trees_fire = Conjunction([FlattenedPrompt([ + CrossAttentionControlSubstitute([Fragment('flames',1)], [Fragment('trees',1)]), + (', fire', 1.0)])]) + self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire')) + self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire')) + self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire')) + self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire')) + self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire ')) + self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire ')) + + + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), + CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]), + parse_prompt('a forest landscape "".swap("in winter")')) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), + CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]), + parse_prompt('a forest landscape " ".swap("in winter")')) + + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), + CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]), + parse_prompt('a forest landscape "in winter".swap("")')) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), + CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]), + parse_prompt('a forest landscape "in winter".swap()')) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), + CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]), + parse_prompt('a forest landscape "in winter".swap(" ")')) + + def test_cross_attention_control_with_attention(self): + flames_to_trees_fire = Conjunction([FlattenedPrompt([ + 
CrossAttentionControlSubstitute([Fragment('flames',0.5)], [Fragment('trees',0.7)]), + Fragment(',', 1), Fragment('fire', 2.0)])]) + self.assertEqual(flames_to_trees_fire, parse_prompt('"(flames)0.5".swap("(trees)0.7"), (fire)2.0')) + flames_to_trees_fire = Conjunction([FlattenedPrompt([ + CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7)]), + Fragment(',', 1), Fragment('fire', 2.0)])]) + self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7"), (fire)2.0')) + flames_to_trees_fire = Conjunction([FlattenedPrompt([ + CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7), Fragment('houses', 1)]), + Fragment(',', 1), Fragment('fire', 2.0)])]) + self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0')) + + def test_cross_attention_control_options(self): + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start':0.1}), + Fragment('eating a hotdog', 1)])]), + parse_prompt("a \"cat\".swap(dog, s_start=0.1) eating a hotdog")) + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'t_start':0.1}), + Fragment('eating a hotdog', 1)])]), + parse_prompt("a \"cat\".swap(dog, t_start=0.1) eating a hotdog")) + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start': 20.0, 't_start':0.1}), + Fragment('eating a hotdog', 1)])]), + parse_prompt("a \"cat\".swap(dog, t_start=0.1, s_start=20) eating a hotdog")) + + self.assertEqual( + Conjunction([ + FlattenedPrompt([Fragment('a fantasy forest landscape', 1), + CrossAttentionControlSubstitute([Fragment('', 1)], [Fragment('with a river', 1)], + options={'s_start': 0.8, 't_start': 0.8})])]), + parse_prompt("a fantasy forest landscape \"\".swap(with a river, s_start=0.8, t_start=0.8)")) + + + def test_escaping(self): + + # make sure ", ( and ) can be escaped + + self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain \(man\)')) + self.assertEqual(make_basic_conjunction(['mountain (man )']),parse_prompt('mountain (\(man)\)')) + self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain (\(man\))')) + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('mountain (\(man\))+')) + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('"mountain" (\(man\))+')) + self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('(man)', 1.1)]), parse_prompt('\\"mountain\\" (\(man\))+')) + # same weights for each are combined into one + self.assertEqual(make_weighted_conjunction([('"mountain" (man)', 1.1)]), parse_prompt('(\\"mountain\\")+ (\(man\))+')) + self.assertEqual(make_weighted_conjunction([('"mountain"', 1.1), ('(man)', 0.9)]), parse_prompt('(\\"mountain\\")+ (\(man\))-')) + + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain (\(man\))1.1')) + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('"mountain" (\(man\))1.1')) + self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('\(man\)', 1.1)]),parse_prompt('\\"mountain\\" 
+        self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain (\(man\))1.1'))
+        self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('"mountain" (\(man\))1.1'))
+        self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('\(man\)', 1.1)]),parse_prompt('\\"mountain\\" (\(man\))1.1'))
+        # same weights for each are combined into one
+        self.assertEqual(make_weighted_conjunction([('\\"mountain\\" \(man\)', 1.1)]),parse_prompt('(\\"mountain\\")+ (\(man\))1.1'))
+        self.assertEqual(make_weighted_conjunction([('\\"mountain\\"', 1.1), ('\(man\)', 0.9)]),parse_prompt('(\\"mountain\\")1.1 (\(man\))0.9'))
+
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy (mountain (\(man\))+)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy ((\(man\))1.1 "mountain")+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy ("mountain" (\(man\))1.1 )+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man', 1.1)]),parse_prompt('hairy ("mountain, man")+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy ("mountain, man" with a beard+)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, man" with a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\"man\\"" with a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, m\\"an\\"" with a (beard)2.0)+'))
+
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \(with a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\(ith a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\( a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \)with a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\) a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain, \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( ("mountain, \\\"man\" with a (beard)2.0)+'))
+
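+        # same escapes, with the weighted span first and the unweighted word after it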
+        self.assertEqual(make_weighted_conjunction([('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \(with a (beard)2.0)+ hairy'))
+        self.assertEqual(make_weighted_conjunction([('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\(ith a (beard)2.0)+hairy'))
+        self.assertEqual(make_weighted_conjunction([('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" with\( a (beard)2.0)+ hairy'))
+        self.assertEqual(make_weighted_conjunction([('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \)with a (beard)2.0)+ hairy'))
+        self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hairy'))
+        self.assertEqual(make_weighted_conjunction([('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' ("mountain, \\\"man\" with\) a (beard)2.0)+ hairy'))
+        self.assertEqual(make_weighted_conjunction([('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+ hairy'))
+        self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hai\(ry '))
+        self.assertEqual(make_weighted_conjunction([('mountain, \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('("mountain, \\\"man\" with a (beard)2.0)+ hairy\(\( '))
+
+    def test_cross_attention_escaping(self):
+
+        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
+                         parse_prompt('mountain (man).swap(monkey)'))
+        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]),
+                         parse_prompt('mountain (man).swap(m\(onkey)'))
+        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]),
+                         parse_prompt('mountain (m\(an).swap(m\(onkey)'))
+        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]),
+                         parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))
+
+        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
+                         parse_prompt('mountain ("man").swap(monkey)'))
+        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
+                         parse_prompt('mountain ("man").swap("monkey")'))
+        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('"man', 1)], [Fragment('monkey', 1)])])]),
+                         parse_prompt('mountain (\\"man).swap("monkey")'))
+        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]),
+                         parse_prompt('mountain (man).swap(m\(onkey)'))
+        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]),
+                         parse_prompt('mountain (m\(an).swap(m\(onkey)'))
+        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]),
+                         parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))
+
+    def test_legacy_blend(self):
+        pp = PromptParser()
+
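+        # legacy "prompt:weight prompt:weight" syntax - weights are normalized and a missing weight defaults to 1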
+        self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
+                                FlattenedPrompt([('man mountain', 1)])],
+                               weights=[0.5,0.5]),
+                         pp.parse_legacy_blend('mountain man:1 man mountain:1'))
+
+        self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
+                                FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
+                               weights=[0.5,0.5]),
+                         pp.parse_legacy_blend('mountain+ man:1 man mountain-:1'))
+
+        self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
+                                FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
+                               weights=[0.5,0.5]),
+                         pp.parse_legacy_blend('mountain+ man:1 man mountain-'))
+
+        self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
+                                FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
+                               weights=[0.5,0.5]),
+                         pp.parse_legacy_blend('mountain+ man: man mountain-:'))
+
+        self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
+                                FlattenedPrompt([('man mountain', 1)])],
+                               weights=[0.75,0.25]),
+                         pp.parse_legacy_blend('mountain man:3 man mountain:1'))
+
+        self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
+                                FlattenedPrompt([('man mountain', 1)])],
+                               weights=[1.0,0.0]),
+                         pp.parse_legacy_blend('mountain man:3 man mountain:0'))
+
+        self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
+                                FlattenedPrompt([('man mountain', 1)])],
+                               weights=[0.8,0.2]),
+                         pp.parse_legacy_blend('"mountain man":4 man mountain'))
+
+
+    def test_single(self):
+        # todo handle this
+        #self.assertEqual(make_basic_conjunction(['a badly formed +test prompt']),
+        #                 parse_prompt('a badly formed +test prompt'))
+        pass
+
+
+if __name__ == '__main__':
+    unittest.main()