Merge branch 'fix-prompts' of https://github.com/damian0815/InvokeAI into merge-prompt-and-inpaint-model

This commit is contained in:
Lincoln Stein 2022-10-26 08:50:55 -04:00
commit 2f1c1e7695
22 changed files with 2077 additions and 173 deletions

View File

@ -14,7 +14,7 @@ from threading import Event
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
from ldm.invoke.conditioning import split_weighted_subprompts
from ldm.invoke.prompt_parser import split_weighted_subprompts
from backend.modules.parameters import parameters_to_command

View File

@ -33,7 +33,7 @@ from ldm.generate import Generate
from ldm.invoke.restoration import Restoration
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
from ldm.invoke.args import APP_ID, APP_VERSION, calculate_init_img_hash
from ldm.invoke.conditioning import split_weighted_subprompts
from ldm.invoke.prompt_parser import split_weighted_subprompts
from modules.parameters import parameters_to_command

View File

@ -76,4 +76,4 @@ model:
target: torch.nn.Identity
cond_stage_config:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
target: ldm.modules.encoders.modules.WeightedFrozenCLIPEmbedder

View File

@ -84,6 +84,48 @@ Getting close - but there's no sense in having a saddle when our horse doesn't h
---
## **Prompt Syntax Features**
The InvokeAI prompting language has the following features:
### Attention weighting
Append `+` or `-`, or a weight between `0` and `2` (`1`=default), to a word or phrase to increase or decrease "attention" (= a mix of per-token CFG weighting multiplier and, for `-`, a weighted blend with the prompt without the term).
The following will be recognised:
* single words without parentheses: `a tall thin man picking apricots+`
* single or multiple words with parentheses: `a tall thin man picking (apricots)+`, `a tall thin man picking (apricots)-`, `a tall thin man (picking apricots)+`, `a tall thin man (picking apricots)-`
* more effect with more symbols: `a tall thin man (picking apricots)++`
* nesting `a tall thin man (picking apricots+)++` (`apricots` effectively gets `+++`)
* all of the above with explicit numbers `a tall thin man picking (apricots)1.1` `a tall thin man (picking (apricots)1.3)1.1`. (`+` is equivalent to 1.1, `++` is pow(1.1,2), `+++` is pow(1.1,3), etc; `-` means 0.9, `--` means pow(0.9,2), etc.)
* attention also applies to `[unconditioning]` so `a tall thin man picking apricots [(ladder)0.01]` will *very gently* nudge SD away from trying to draw the man on a ladder
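For a concrete sense of how the weights accumulate, here is a minimal sketch using the `PromptParser` added in this commit (the repr in the comment is approximate):

```python
from ldm.invoke.prompt_parser import PromptParser

pp = PromptParser()  # attention_plus_base=1.1, attention_minus_base=0.9 by default
conjunction = pp.parse_conjunction("a tall thin man (picking apricots)++")
# adjacent fragments with equal weight are fused, so this should flatten to roughly:
# FlattenedPrompt:[Fragment:'a tall thin man'@1.0, Fragment:'picking apricots'@1.21]
print(conjunction.prompts[0])
```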
### Blending between prompts
* `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,1)`
* The existing prompt blending using `:<weight>` will continue to be supported - `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,1)` is equivalent to `a tall thin man picking apricots:1 a tall thin man picking pears:1` in the old syntax.
* Attention weights can be nested inside blends.
* Non-normalized blends are supported by passing `no_normalize` as an additional argument to the blend weights, e.g. `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,-1,no_normalize)`. Very fun for exploring local maxima in the feature space, but also easy to produce garbage output.
See the section below on "Prompt Blending" for more information about how this works.
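As a quick sketch of what the parser produces here, the blend syntax parses to a `Blend` object holding one flattened prompt per quoted string; normalization of the weights happens later, when the embeddings are combined:

```python
from ldm.invoke.prompt_parser import PromptParser, Blend

pp = PromptParser()
c = pp.parse_conjunction('("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,1)')
blend = c.prompts[0]
assert type(blend) is Blend
print(blend.weights)   # [1.0, 1.0]
```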
### Cross-Attention Control ('prompt2prompt')
Denoise with a given prompt and then re-use the attention→pixel maps to substitute words in the original prompt with words from a new prompt. Based on [bloc97's colab](https://github.com/bloc97/CrossAttentionControl).
* `a ("fluffy cat").swap("smiling dog") eating a hotdog`.
* quotes optional: `a (fluffy cat).swap(smiling dog) eating a hotdog`.
* for single word substitutions parentheses are also optional: `a cat.swap(dog) eating a hotdog`.
* Supports options `s_start`, `s_end`, `t_start`, `t_end` (each 0-1) loosely corresponding to bloc97's `prompt_edit_spatial_start/_end` and `prompt_edit_tokens_start/_end` but with the math swapped to make it easier to intuitively understand.
* Example usage: `a (cat).swap(dog, s_end=0.3) eating a hotdog` - the `s_end` argument means that the "spatial" (self-attention) edit will stop having any effect after 30% (=0.3) of the steps have been done, leaving Stable Diffusion with 70% of the steps where it is free to decide for itself how to reshape the cat-form into a dog form.
* The numbers represent a percentage through the step sequence where the edits should happen. 0 means the start (noisy starting image), 1 is the end (final image).
* For img2img, the step sequence does not start at 0 but instead at (1-strength) - so if strength is 0.7, s_start and s_end must both be greater than 0.3 (1-0.7) to have any effect.
* Convenience option `shape_freedom` (0-1) to specify how much "freedom" Stable Diffusion should have to change the shape of the subject being swapped.
* `a (cat).swap(dog, shape_freedom=0.5) eating a hotdog`.
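A sketch of how a `.swap` parses - the substitution becomes a `CrossAttentionControlSubstitute` fragment carrying the merged options (default values as defined in this commit's `prompt_parser.py`):

```python
from ldm.invoke.prompt_parser import PromptParser, CrossAttentionControlSubstitute

pp = PromptParser()
c = pp.parse_conjunction("a cat.swap(dog, s_end=0.3) eating a hotdog")
swap = next(x for x in c.prompts[0].children if type(x) is CrossAttentionControlSubstitute)
print(swap.options)  # expected: {'s_start': 0.0, 's_end': 0.3, 't_start': 0.0, 't_end': 1.0}
```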
### Escaping parentheses () and speech marks ""
If the model you are using has parentheses () or speech marks "" as part of its syntax, you will need to "escape" them with a backslash, so that `(my_keyword)` becomes `\(my_keyword\)`. Otherwise, the prompt parser will attempt to interpret the parentheses as part of the prompt syntax and will get confused.
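A sketch of the escape handling - the `Fragment` class in this commit strips the backslashes again before the text reaches the tokenizer:

```python
from ldm.invoke.prompt_parser import PromptParser

pp = PromptParser()
c = pp.parse_conjunction(r'a \(happy\) dog')
# the escaped parentheses are treated as plain text and unescaped inside Fragment,
# so this should flatten to roughly: FlattenedPrompt:[Fragment:'a (happy) dog'@1.0]
print(c.prompts[0])
```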
## **Prompt Blending**
You may blend together different sections of the prompt to explore the

View File

@ -1,5 +1,5 @@
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
import pyparsing
# Derived from source code carrying the following copyrights
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
@ -24,6 +24,7 @@ from PIL import Image, ImageOps
from torch import nn
from pytorch_lightning import seed_everything, logging
from ldm.invoke.prompt_parser import PromptParser
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
@ -32,7 +33,7 @@ from ldm.invoke.pngwriter import PngWriter
from ldm.invoke.args import metadata_from_png
from ldm.invoke.image_util import InitImageResizer
from ldm.invoke.devices import choose_torch_device, choose_precision
from ldm.invoke.conditioning import get_uc_and_c
from ldm.invoke.conditioning import get_uc_and_c_and_ec
from ldm.invoke.model_cache import ModelCache
from ldm.invoke.seamless import configure_model_padding
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
@ -404,7 +405,7 @@ class Generate:
mask_image = None
try:
uc, c = get_uc_and_c(
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
prompt, model =self.model,
skip_normalize=skip_normalize,
log_tokens =self.log_tokenization
@ -448,7 +449,7 @@ class Generate:
sampler=self.sampler,
steps=steps,
cfg_scale=cfg_scale,
conditioning=(uc, c),
conditioning=(uc, c, extra_conditioning_info),
ddim_eta=ddim_eta,
image_callback=image_callback, # called after the final image is generated
step_callback=step_callback, # called after each intermediate image is generated
@ -481,14 +482,14 @@ class Generate:
save_original = save_original,
image_callback = image_callback)
except RuntimeError as e:
print(traceback.format_exc(), file=sys.stderr)
print('>> Could not generate image.')
except KeyboardInterrupt:
if catch_interrupts:
print('**Interrupted** Partial results will be returned.')
else:
raise KeyboardInterrupt
except RuntimeError as e:
print(traceback.format_exc(), file=sys.stderr)
print('>> Could not generate image.')
toc = time.time()
print('>> Usage stats:')
@ -553,7 +554,8 @@ class Generate:
image = Image.open(image_path)
# used by multiple postfixers
uc, c = get_uc_and_c(
# todo: cross-attention control
uc, c, _ = get_uc_and_c_and_ec(
prompt, model =self.model,
skip_normalize=opt.skip_normalize,
log_tokens =opt.log_tokenization

View File

@ -92,7 +92,7 @@ import copy
import base64
import functools
import ldm.invoke.pngwriter
from ldm.invoke.conditioning import split_weighted_subprompts
from ldm.invoke.prompt_parser import split_weighted_subprompts
SAMPLER_CHOICES = [
'ddim',

View File

@ -4,107 +4,166 @@ weighted subprompts.
Useful function exports:
get_uc_and_c() get the conditioned and unconditioned latent
get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
split_weighted_subprompts() split subprompts, normalize and weight them
log_tokenization() print out colour-coded tokens and warn if truncated
'''
import re
from difflib import SequenceMatcher
from typing import Union
import torch
def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment
from ..models.diffusion.cross_attention_control import CrossAttentionControl
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):
# Extract Unconditioned Words From Prompt
unconditioned_words = ''
unconditional_regex = r'\[(.*?)\]'
unconditionals = re.findall(unconditional_regex, prompt)
unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)
if len(unconditionals) > 0:
unconditioned_words = ' '.join(unconditionals)
# Remove Unconditioned Words From Prompt
unconditional_regex_compile = re.compile(unconditional_regex)
clean_prompt = unconditional_regex_compile.sub(' ', prompt)
prompt = re.sub(' +', ' ', clean_prompt)
clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned)
prompt_string_cleaned = re.sub(' +', ' ', clean_prompt)
else:
prompt_string_cleaned = prompt_string_uncleaned
uc = model.get_learned_conditioning([unconditioned_words])
pp = PromptParser()
# get weighted sub-prompts
weighted_subprompts = split_weighted_subprompts(
prompt, skip_normalize
parsed_prompt: Union[FlattenedPrompt, Blend] = None
legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned)
if legacy_blend is not None:
parsed_prompt = legacy_blend
else:
# we don't support conjunctions for now
parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
print(f">> Parsed prompt to {parsed_prompt}")
conditioning = None
cac_args:CrossAttentionControl.Arguments = None
if type(parsed_prompt) is Blend:
blend: Blend = parsed_prompt
embeddings_to_blend = None
for flattened_prompt in blend.prompts:
this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
(embeddings_to_blend, this_embedding))
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
blend.weights,
normalize=blend.normalize_weights)
else:
flattened_prompt: FlattenedPrompt = parsed_prompt
wants_cross_attention_control = type(flattened_prompt) is not Blend \
and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
if wants_cross_attention_control:
original_prompt = FlattenedPrompt()
edited_prompt = FlattenedPrompt()
# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
original_token_count = 0
edited_token_count = 0
edit_opcodes = []
edit_options = []
for fragment in flattened_prompt.children:
if type(fragment) is CrossAttentionControlSubstitute:
original_prompt.append(fragment.original)
edited_prompt.append(fragment.edited)
to_replace_token_count = get_tokens_length(model, fragment.original)
replacement_token_count = get_tokens_length(model, fragment.edited)
edit_opcodes.append(('replace',
original_token_count, original_token_count + to_replace_token_count,
edited_token_count, edited_token_count + replacement_token_count
))
original_token_count += to_replace_token_count
edited_token_count += replacement_token_count
edit_options.append(fragment.options)
#elif type(fragment) is CrossAttentionControlAppend:
# edited_prompt.append(fragment.fragment)
else:
# regular fragment
original_prompt.append(fragment)
edited_prompt.append(fragment)
count = get_tokens_length(model, [fragment])
edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count))
edit_options.append(None)
original_token_count += count
edited_token_count += count
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt, log_tokens=log_tokens)
# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
# subsequent tokens when there is >1 edit and earlier edits change the total token count.
# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
# token 'smiling' in the inactive 'cat' edit.
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt, log_tokens=log_tokens)
conditioning = original_embeddings
edited_conditioning = edited_embeddings
print('got edit_opcodes', edit_opcodes, 'options', edit_options)
cac_args = CrossAttentionControl.Arguments(
edited_conditioning = edited_conditioning,
edit_opcodes = edit_opcodes,
edit_options = edit_options
)
else:
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
return (
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
cross_attention_control_args=cac_args
)
)
if len(weighted_subprompts) > 1:
# I don't know if this is correct, but it works
c = torch.zeros_like(uc)
# normalize each "sub prompt" and add it
for subprompt, weight in weighted_subprompts:
log_tokenization(subprompt, model, log_tokens, weight)
c = torch.add(
c,
model.get_learned_conditioning([subprompt]),
alpha=weight,
)
else: # just standard 1 prompt
log_tokenization(prompt, model, log_tokens, 1)
c = model.get_learned_conditioning([prompt])
uc = model.get_learned_conditioning([unconditioned_words])
return (uc, c)
def split_weighted_subprompts(text, skip_normalize=False)->list:
"""
grabs all text up to the first occurrence of ':'
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
if ':' has no value defined, defaults to 1.0
repeats until no text remaining
"""
prompt_parser = re.compile("""
(?P<prompt> # capture group for 'prompt'
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
) # end 'prompt'
(?: # non-capture group
:+ # match one or more ':' characters
(?P<weight> # capture group for 'weight'
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
)? # end weight capture group, make optional
\s* # strip spaces after weight
| # OR
$ # else, if no ':' then match end of line
) # end non-capture group
""", re.VERBOSE)
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
if skip_normalize:
return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
if weight_sum == 0:
print(
"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
equal_weight = 1 / max(len(parsed_prompts), 1)
return [(x[0], equal_weight) for x in parsed_prompts]
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
def build_token_edit_opcodes(original_tokens, edited_tokens):
original_tokens = original_tokens.cpu().numpy()[0]
edited_tokens = edited_tokens.cpu().numpy()[0]
return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
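# (note) get_opcodes() yields 5-tuples such as ('equal', a0, a1, b0, b1) or ('replace', a0, a1, b0, b1),
# mapping token span [a0:a1] of the original sequence onto span [b0:b1] of the edited sequence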
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False):
if type(flattened_prompt) is not FlattenedPrompt:
raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
fragments = [x.text for x in flattened_prompt.children]
weights = [x.weight for x in flattened_prompt.children]
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
if not flattened_prompt.is_empty and log_tokens:
start_token = model.cond_stage_model.tokenizer.bos_token_id
end_token = model.cond_stage_model.tokenizer.eos_token_id
tokens_list = tokens[0].tolist()
if tokens_list[0] == start_token:
tokens_list[0] = '<start>'
try:
first_end_token_index = tokens_list.index(end_token)
tokens_list[first_end_token_index] = '<end>'
tokens_list = tokens_list[:first_end_token_index+1]
except ValueError:
pass
print(f">> Prompt fragments {fragments}, tokenized to \n{tokens_list}")
return embeddings, tokens
def get_tokens_length(model, fragments: list[Fragment]):
fragment_texts = [x.text for x in fragments]
tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
return sum([len(x) for x in tokens])
# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, log=False, weight=1):
if not log:
return
tokens = model.cond_stage_model.tokenizer._tokenize(text)
tokenized = ""
discarded = ""
usedTokens = 0
totalTokens = len(tokens)
for i in range(0, totalTokens):
token = tokens[i].replace('</w>', ' ')
# alternate color
s = (usedTokens % 6) + 1
if i < model.cond_stage_model.max_length:
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
usedTokens += 1
else: # over max token length
discarded = discarded + f"\x1b[0;3{s};40m{token}"
print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
if discarded != "":
print(
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
)

View File

@ -10,6 +10,7 @@ from PIL import Image
from ldm.invoke.devices import choose_autocast
from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
class Img2Img(Generator):
def __init__(self, model, precision):
@ -38,7 +39,7 @@ class Img2Img(Generator):
) # move to latent space
t_enc = int(strength * steps)
uc, c = conditioning
uc, c, extra_conditioning_info = conditioning
def make_image(x_T):
# encode (scaled latent)
@ -56,6 +57,8 @@ class Img2Img(Generator):
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
init_latent = self.init_latent, # changes how noising is performed in ksampler
extra_conditioning_info = extra_conditioning_info,
all_timesteps_count = steps
)
return self.sample_to_image(samples)

View File

@ -73,7 +73,8 @@ class Inpaint(Img2Img):
) # move to latent space
t_enc = int(strength * steps)
uc, c = conditioning
# todo: support cross-attention control
uc, c, _ = conditioning
print(f">> target t_enc is {t_enc} steps")

View File

@ -5,6 +5,8 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
import torch
import numpy as np
from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
class Txt2Img(Generator):
def __init__(self, model, precision):
@ -19,7 +21,7 @@ class Txt2Img(Generator):
kwargs are 'width' and 'height'
"""
self.perlin = perlin
uc, c = conditioning
uc, c, extra_conditioning_info = conditioning
@torch.no_grad()
def make_image(x_T):
@ -43,6 +45,7 @@ class Txt2Img(Generator):
verbose = False,
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc,
extra_conditioning_info = extra_conditioning_info,
eta = ddim_eta,
img_callback = step_callback,
threshold = threshold,

View File

@ -7,6 +7,7 @@ import numpy as np
import math
from ldm.invoke.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
class Txt2Img2Img(Generator):
@ -22,7 +23,7 @@ class Txt2Img2Img(Generator):
Return value depends on the seed at the time you call it
kwargs are 'width' and 'height'
"""
uc, c = conditioning
uc, c, extra_conditioning_info = conditioning
@torch.no_grad()
def make_image(x_T):
@ -60,7 +61,8 @@ class Txt2Img2Img(Generator):
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc,
eta = ddim_eta,
img_callback = step_callback
img_callback = step_callback,
extra_conditioning_info = extra_conditioning_info
)
print(
@ -94,6 +96,8 @@ class Txt2Img2Img(Generator):
img_callback = step_callback,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
extra_conditioning_info=extra_conditioning_info,
all_timesteps_count=steps
)
if self.free_gpu_mem:

680
ldm/invoke/prompt_parser.py Normal file
View File

@ -0,0 +1,680 @@
import string
from typing import Union, Optional
import re
import pyparsing as pp
class Prompt():
"""
Mid-level structure for storing the tree-like result of parsing a prompt. A Prompt may not represent the whole of
the singular user-defined "prompt string" (although it can) - for example, if the user specifies a Blend, the objects
that are to be blended together are stored individually as Prompt objects.
Nesting makes this object not suitable for directly tokenizing; instead call flatten() on the containing Conjunction
to produce a FlattenedPrompt.
"""
def __init__(self, parts: list):
for c in parts:
if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} {c}, only {BaseFragment.__subclasses__()} are allowed")
self.children = parts
def __repr__(self):
return f"Prompt:{self.children}"
def __eq__(self, other):
return type(other) is Prompt and other.children == self.children
class BaseFragment:
pass
class FlattenedPrompt():
"""
A Prompt that has been passed through flatten(). Its children can be readily tokenized.
"""
def __init__(self, parts: list=[]):
self.children = []
for part in parts:
self.append(part)
def append(self, fragment: Union[list, BaseFragment, tuple]):
# verify type correctness
if type(fragment) is list:
for x in fragment:
self.append(x)
elif issubclass(type(fragment), BaseFragment):
self.children.append(fragment)
elif type(fragment) is tuple:
# upgrade tuples to Fragments
if type(fragment[0]) is not str or (type(fragment[1]) is not float and type(fragment[1]) is not int):
raise PromptParser.ParsingException(
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
self.children.append(Fragment(fragment[0], fragment[1]))
else:
raise PromptParser.ParsingException(
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
@property
def is_empty(self):
return len(self.children) == 0 or \
(len(self.children) == 1 and len(self.children[0].text) == 0)
def __repr__(self):
return f"FlattenedPrompt:{self.children}"
def __eq__(self, other):
return type(other) is FlattenedPrompt and other.children == self.children
class Fragment(BaseFragment):
"""
A Fragment is a chunk of plain text and an optional weight. The text should be passed as-is to the CLIP tokenizer.
"""
def __init__(self, text: str, weight: float=1):
assert(type(text) is str)
if '\\"' in text or '\\(' in text or '\\)' in text:
#print("Fragment converting escaped \( \) \\\" into ( ) \"")
text = text.replace('\\(', '(').replace('\\)', ')').replace('\\"', '"')
self.text = text
self.weight = float(weight)
def __repr__(self):
return "Fragment:'"+self.text+"'@"+str(self.weight)
def __eq__(self, other):
return type(other) is Fragment \
and other.text == self.text \
and other.weight == self.weight
class Attention():
"""
Nestable weight control for fragments. Each object in the children array may in turn be an Attention object;
weights should be considered to accumulate as the tree is traversed to deeper levels of nesting.
Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object.
"""
def __init__(self, weight: float, children: list):
self.weight = weight
self.children = children
#print(f"A: requested attention '{children}' to {weight}")
def __repr__(self):
return f"Attention:'{self.children}' @ {self.weight}"
def __eq__(self, other):
return type(other) is Attention and other.weight == self.weight and other.children == self.children
class CrossAttentionControlledFragment(BaseFragment):
pass
class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
"""
A Cross-Attention Controlled ('prompt2prompt') fragment, for use inside a Prompt, Attention, or FlattenedPrompt.
Representing an "original" word sequence that supplies feature vectors for an initial diffusion operation, and an
"edited" word sequence, to which the attention maps produced by the "original" word sequence are applied. Intuitively,
the result should be an "edited" image that looks like the "original" image with concepts swapped.
eg "a cat sitting on a car" (original) -> "a smiling dog sitting on a car" (edited): the edited image should look
almost exactly the same as the original, but with a smiling dog rendered in place of the cat. The
CrossAttentionControlSubstitute object representing this swap may be confined to the tokens being swapped:
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')])
or it may represent a larger portion of the token sequence:
CrossAttentionControlSubstitute(original=[Fragment('a cat sitting on a car')],
edited=[Fragment('a smiling dog sitting on a car')])
In either case expect it to be embedded in a Prompt or FlattenedPrompt:
FlattenedPrompt([
Fragment('a'),
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]),
Fragment('sitting on a car')
])
"""
def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list], options: dict=None):
self.original = original
self.edited = edited
default_options = {
's_start': 0.0,
's_end': 0.206, # ~= shape_freedom=0.5
't_start': 0.0,
't_end': 1.0
}
merged_options = default_options
if options is not None:
shape_freedom = options.pop('shape_freedom', None)
if shape_freedom is not None:
# high shape freedom = SD can do what it wants with the shape of the object
# high shape freedom => s_end = 0
# low shape freedom => s_end = 1
# shape freedom is in a "linear" space, while noticeable changes to s_end are typically closer around 0,
# and there is very little perceptible difference as s_end increases above 0.5
# so for shape_freedom = 0.5 we probably want s_end to be 0.2
# -> cube root and subtract from 1.0
merged_options['s_end'] = 1.0 - shape_freedom ** (1. / 3.)
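# e.g. shape_freedom=0.5 -> s_end = 1.0 - 0.5**(1./3.) ≈ 1.0 - 0.794 ≈ 0.206, matching the default above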
print('converted shape_freedom argument to', merged_options)
merged_options.update(options)
self.options = merged_options
def __repr__(self):
return f"CrossAttentionControlSubstitute:({self.original}->{self.edited} ({self.options})"
def __eq__(self, other):
return type(other) is CrossAttentionControlSubstitute \
and other.original == self.original \
and other.edited == self.edited \
and other.options == self.options
class CrossAttentionControlAppend(CrossAttentionControlledFragment):
def __init__(self, fragment: Fragment):
self.fragment = fragment
def __repr__(self):
return "CrossAttentionControlAppend:",self.fragment
def __eq__(self, other):
return type(other) is CrossAttentionControlAppend \
and other.fragment == self.fragment
class Conjunction():
"""
Storage for one or more Prompts or Blends, each of which is to be separately diffused and then the results merged
by weighted sum in latent space.
"""
def __init__(self, prompts: list, weights: list = None):
# force everything to be a Prompt
#print("making conjunction with", parts)
self.prompts = [x if (type(x) is Prompt
or type(x) is Blend
or type(x) is FlattenedPrompt)
else Prompt(x) for x in prompts]
self.weights = [1.0]*len(self.prompts) if weights is None else list(weights)
if len(self.weights) != len(self.prompts):
raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
self.type = 'AND'
def __repr__(self):
return f"Conjunction:{self.prompts} | weights {self.weights}"
def __eq__(self, other):
return type(other) is Conjunction \
and other.prompts == self.prompts \
and other.weights == self.weights
class Blend():
"""
Stores a Blend of multiple Prompts. To apply, build feature vectors for each of the child Prompts and then perform a
weighted blend of the feature vectors to produce a single feature vector that is effectively a lerp between the
Prompts.
"""
def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
#print("making Blend with prompts", prompts, "and weights", weights)
if len(prompts) != len(weights):
raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
for c in prompts:
if type(c) is not Prompt and type(c) is not FlattenedPrompt:
raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
# upcast all lists to Prompt objects
self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
else Prompt(x) for x in prompts]
self.weights = weights
self.normalize_weights = normalize_weights
def __repr__(self):
return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
def __eq__(self, other):
return other.__repr__() == self.__repr__()
class PromptParser():
class ParsingException(Exception):
pass
def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base)
def parse_conjunction(self, prompt: str) -> Conjunction:
'''
:param prompt: The prompt string to parse
:return: a Conjunction representing the parsed results.
'''
#print(f"!!parsing '{prompt}'")
if len(prompt.strip()) == 0:
return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0])
root = self.conjunction.parse_string(prompt)
#print(f"'{prompt}' parsed to root", root)
#fused = fuse_fragments(parts)
#print("fused to", fused)
return self.flatten(root[0])
def parse_legacy_blend(self, text: str) -> Optional[Blend]:
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False)
if len(weighted_subprompts) == 1:
return None
strings = [x[0] for x in weighted_subprompts]
weights = [x[1] for x in weighted_subprompts]
parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
def flatten(self, root: Conjunction) -> Conjunction:
"""
Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
that can be readily tokenized without the need to walk a complex tree structure.
:param root: The Conjunction to flatten.
:return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
"""
#print("flattening", root)
def fuse_fragments(items):
# print("fusing fragments in ", items)
result = []
for x in items:
if type(x) is CrossAttentionControlSubstitute:
original_fused = fuse_fragments(x.original)
edited_fused = fuse_fragments(x.edited)
result.append(CrossAttentionControlSubstitute(original_fused, edited_fused, options=x.options))
else:
last_weight = result[-1].weight \
if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \
else None
this_text = x.text
this_weight = x.weight
if last_weight is not None and last_weight == this_weight:
last_text = result[-1].text
result[-1] = Fragment(last_text + ' ' + this_text, last_weight)
else:
result.append(x)
return result
def flatten_internal(node, weight_scale, results, prefix):
#print(prefix + "flattening", node, "...")
if type(node) is pp.ParseResults:
for x in node:
results = flatten_internal(x, weight_scale, results, prefix+' pr ')
#print(prefix, " ParseResults expanded, results is now", results)
elif type(node) is Attention:
# if node.weight < 1:
# todo: inject a blend when flattening attention with weight <1"
for index,c in enumerate(node.children):
results = flatten_internal(c, weight_scale * node.weight, results, prefix + f" att{index} ")
elif type(node) is Fragment:
results += [Fragment(node.text, node.weight*weight_scale)]
elif type(node) is CrossAttentionControlSubstitute:
original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ')
edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ')
results += [CrossAttentionControlSubstitute(original, edited, options=node.options)]
elif type(node) is Blend:
flattened_subprompts = []
#print(" flattening blend with prompts", node.prompts, "weights", node.weights)
for prompt in node.prompts:
# prompt is a list
flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ')
results += [Blend(prompts=flattened_subprompts, weights=node.weights, normalize_weights=node.normalize_weights)]
elif type(node) is Prompt:
#print(prefix + "about to flatten Prompt with children", node.children)
flattened_prompt = []
for child in node.children:
flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ')
results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))]
#print(prefix + "after flattening Prompt, results is", results)
else:
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
#print(prefix + "-> after flattening", type(node).__name__, "results is", results)
return results
flattened_parts = []
for part in root.prompts:
flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
#print("flattened to", flattened_parts)
weights = root.weights
return Conjunction(flattened_parts, weights)
def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
lparen = pp.Literal("(").suppress()
rparen = pp.Literal(")").suppress()
quotes = pp.Literal('"').suppress()
comma = pp.Literal(",").suppress()
# accepts int or float notation, always maps to float
number = pp.pyparsing_common.real | \
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
greedy_word = pp.Word(pp.printables, exclude_chars=string.whitespace).set_name('greedy_word')
attention = pp.Forward()
quoted_fragment = pp.Forward()
parenthesized_fragment = pp.Forward()
cross_attention_substitute = pp.Forward()
prompt_part = pp.Forward()
def make_text_fragment(x):
#print("### making fragment for", x)
if type(x) is str:
return Fragment(x)
elif type(x) is pp.ParseResults or type(x) is list:
#print(f'converting {type(x).__name__} to Fragment')
return Fragment(' '.join([s for s in x]))
else:
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
def build_escaped_word_parser(escaped_chars_to_ignore: str):
terms = []
for c in escaped_chars_to_ignore:
terms.append(pp.Literal('\\'+c))
terms.append(
#pp.CharsNotIn(string.whitespace + escaped_chars_to_ignore, exact=1)
pp.Word(pp.printables, exclude_chars=string.whitespace + escaped_chars_to_ignore)
)
return pp.Combine(pp.OneOrMore(
pp.MatchFirst(terms)
))
def build_escaped_word_parser_charbychar(escaped_chars_to_ignore: str):
escapes = []
for c in escaped_chars_to_ignore:
escapes.append(pp.Literal('\\'+c))
return pp.Combine(pp.OneOrMore(
pp.MatchFirst(escapes + [pp.CharsNotIn(
string.whitespace + escaped_chars_to_ignore,
exact=1
)])
))
def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
#print(f"parsing fragment string \"{x}\"")
fragment_string = x[0]
if len(fragment_string.strip()) == 0:
return Fragment('')
if in_quotes:
# escape unescaped quotes
fragment_string = fragment_string.replace('"', '\\"')
#fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment))))
result = pp.Group(pp.MatchFirst([
pp.OneOrMore(prompt_part | quoted_fragment),
pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd()
])).set_name('rr').set_debug(False).parse_string(fragment_string)
#result = (pp.OneOrMore(attention | unquoted_word) + pp.StringEnd()).parse_string(x[0])
#print("parsed to", result)
return result
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment')
escaped_quote = pp.Literal('\\"')#.set_parse_action(lambda x: '"')
escaped_lparen = pp.Literal('\\(')#.set_parse_action(lambda x: '(')
escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')')
escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"')
def not_ends_with_swap(x):
#print("trying to match:", x)
return not x[0].endswith('.swap')
unquoted_fragment = pp.Combine(pp.OneOrMore(
escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash |
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()')))
unquoted_fragment.set_parse_action(make_text_fragment).set_name('unquoted_fragment').set_debug(False)
#print(unquoted_fragment.parse_string("cat.swap(dog)"))
parenthesized_fragment << pp.Or([
(lparen + quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False) + rparen).set_name('-quoted_paren_internal').set_debug(False),
(lparen + rparen).set_parse_action(lambda x: make_text_fragment('')).set_name('-()').set_debug(False),
(lparen + pp.Combine(pp.OneOrMore(
escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash |
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') |
pp.Word(string.whitespace)
)).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False) + rparen)]).set_name('-unquoted_paren_internal').set_debug(False)
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False)
debug_attention = False
# attention control of the form (phrase)+ / (phrase)- / (phrase)<weight>
# phrase can be multiple words, can have multiple +/- signs to strengthen the effect, or an explicit floating point or integer weight
attention_with_parens = pp.Forward()
attention_without_parens = pp.Forward()
attention_with_parens_foot = (number | pp.Word('+') | pp.Word('-'))\
.set_name("attention_foot")\
.set_debug(False)
attention_with_parens <<= pp.Group(
lparen +
pp.ZeroOrMore(quoted_fragment | attention_with_parens | parenthesized_fragment | cross_attention_substitute | attention_without_parens |
(pp.Empty() + build_escaped_word_parser_charbychar('()')).set_name('undecorated_word').set_debug(debug_attention)#.set_parse_action(lambda t: t[0])
)
+ rparen + attention_with_parens_foot)
attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention)
attention_without_parens_foot = pp.Or(pp.Word('+') | pp.Word('-')).set_name('attention_without_parens_foots')
attention_without_parens <<= pp.Group(
(quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot) |
pp.Combine(build_escaped_word_parser_charbychar('()+-')).set_name('attention_word_without_parens').set_debug(debug_attention)#.set_parse_action(lambda x: print('escapéd', x))
+ attention_without_parens_foot)#.leave_whitespace()
attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention)
attention << pp.MatchFirst([attention_with_parens,
attention_without_parens
])
attention.set_name('attention')
def make_attention(x):
#print("entered make_attention with", x)
children = x[0][:-1]
weight_raw = x[0][-1]
weight = 1.0
if type(weight_raw) is float or type(weight_raw) is int:
weight = weight_raw
elif type(weight_raw) is str:
base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
weight = pow(base, len(weight_raw))
#print("making Attention from", children, "with weight", weight)
return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in children])
attention_with_parens.set_parse_action(make_attention)
attention_without_parens.set_parse_action(make_attention)
#print("parsing test:", attention_with_parens.parse_string("mountain (man)1.1"))
# cross-attention control
empty_string = ((lparen + rparen) |
pp.Literal('""').suppress() |
(lparen + pp.Literal('""').suppress() + rparen)
).set_parse_action(lambda x: Fragment(""))
empty_string.set_name('empty_string')
# cross attention control
debug_cross_attention_control = False
original_fragment = pp.Or([empty_string.set_debug(debug_cross_attention_control),
quoted_fragment.set_debug(debug_cross_attention_control),
parenthesized_fragment.set_debug(debug_cross_attention_control),
pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_text_fragment) + pp.FollowedBy(".swap")
])
# support keyword=number arguments
cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")])
cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number)
edited_fragment = pp.MatchFirst([
lparen +
(quoted_fragment |
pp.Group(pp.OneOrMore(pp.Word(pp.printables, exclude_chars=string.whitespace + ',').set_parse_action(make_text_fragment)))
) +
pp.Dict(pp.OneOrMore(comma + cross_attention_option)) +
rparen,
parenthesized_fragment
])
cross_attention_substitute << original_fragment + pp.Literal(".swap").suppress() + edited_fragment
original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control)
cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control)
def make_cross_attention_substitute(x):
#print("making cacs for", x[0], "->", x[1], "with options", x.as_dict())
#if len(x>2):
cacs = CrossAttentionControlSubstitute(x[0], x[1], options=x.as_dict())
#print("made", cacs)
return cacs
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
# simple fragments of text
# use Or to match the longest
prompt_part << pp.MatchFirst([
cross_attention_substitute,
attention,
unquoted_fragment,
lparen + unquoted_fragment + rparen # matches case where user has +(term) and just deletes the +
])
prompt_part.set_debug(False)
prompt_part.set_name("prompt_part")
empty = (
(lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) |
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')
# root prompt definition
prompt = (pp.OneOrMore(pp.Or([prompt_part, quoted_fragment, empty])) + pp.StringEnd()) \
.set_parse_action(lambda x: Prompt(x))
#print("parsing test:", prompt.parse_string("spaced eyes--"))
#print("parsing test:", prompt.parse_string("eyes--"))
# weighted blend of prompts
# ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or
# int weights.
# can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c)
def make_prompt_from_quoted_string(x):
#print(' got quoted prompt', x)
x_unquoted = x[0][1:-1]
if len(x_unquoted.strip()) == 0:
# print(' b : just an empty string')
return Prompt([Fragment('')])
# print(' b parsing ', c_unquoted)
x_parsed = prompt.parse_string(x_unquoted)
#print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed)
return x_parsed[0]
quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
quoted_prompt.set_name('quoted_prompt')
debug_blend=False
blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend)
blend_weights = (pp.delimited_list(number) + pp.Optional(pp.Char(",").suppress() + "no_normalize")).set_name('blend_weights').set_debug(debug_blend)
blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
+ pp.Literal(".blend").suppress()
+ lparen + pp.Group(blend_weights) + rparen).set_name('blend')
blend.set_debug(debug_blend)
def make_blend(x):
prompts = x[0][0]
weights = x[0][1]
normalize = True
if weights[-1] == 'no_normalize':
normalize = False
weights = weights[:-1]
return Blend(prompts=prompts, weights=weights, normalize_weights=normalize)
blend.set_parse_action(make_blend)
conjunction_terms = blend_terms.copy().set_name('conjunction_terms')
conjunction_weights = blend_weights.copy().set_name('conjunction_weights')
conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen
+ pp.Literal(".and").suppress()
+ lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction')
def make_conjunction(x):
parts_raw = x[0][0]
weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw)
parts = [part for part in parts_raw]
return Conjunction(parts, weights)
conjunction_with_parens_and_quotes.set_parse_action(make_conjunction)
implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction')
implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
conjunction = conjunction_with_parens_and_quotes | implicit_conjunction
conjunction.set_debug(False)
# top-level is a conjunction of one or more blends or prompts
return conjunction, prompt
def split_weighted_subprompts(text, skip_normalize=False)->list:
"""
Legacy blend parsing.
grabs all text up to the first occurrence of ':'
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
if ':' has no value defined, defaults to 1.0
repeats until no text remaining
"""
prompt_parser = re.compile("""
(?P<prompt> # capture group for 'prompt'
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
) # end 'prompt'
(?: # non-capture group
:+ # match one or more ':' characters
(?P<weight> # capture group for 'weight'
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
)? # end weight capture group, make optional
\s* # strip spaces after weight
| # OR
$ # else, if no ':' then match end of line
) # end non-capture group
""", re.VERBOSE)
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
if skip_normalize:
return parsed_prompts
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
if weight_sum == 0:
print(
"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
equal_weight = 1 / max(len(parsed_prompts), 1)
return [(x[0], equal_weight) for x in parsed_prompts]
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
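# e.g. split_weighted_subprompts("mountain:1 lake:3") -> [('mountain', 0.25), ('lake', 0.75)]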
# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, log=False, weight=1):
if not log:
return
tokens = model.cond_stage_model.tokenizer._tokenize(text)
tokenized = ""
discarded = ""
usedTokens = 0
totalTokens = len(tokens)
for i in range(0, totalTokens):
token = tokens[i].replace('</w>', ' ')
# alternate color
s = (usedTokens % 6) + 1
if i < model.cond_stage_model.max_length:
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
usedTokens += 1
else: # over max token length
discarded = discarded + f"\x1b[0;3{s};40m{token}"
print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
if discarded != "":
print(
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
)

View File

@ -0,0 +1,238 @@
from enum import Enum
import torch
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl
class CrossAttentionControl:
class Arguments:
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: list):
"""
:param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768]
:param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required)
:param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes.
"""
# todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector
self.edited_conditioning = edited_conditioning
self.edit_opcodes = edit_opcodes
if edited_conditioning is not None:
assert len(edit_opcodes) == len(edit_options), \
"there must be 1 edit_options dict for each edit_opcodes tuple"
non_none_edit_options = [x for x in edit_options if x is not None]
assert len(non_none_edit_options)>0, "missing edit_options"
if len(non_none_edit_options)>1:
print('warning: cross-attention control options are not working properly for >1 edit')
self.edit_options = non_none_edit_options[0]
class Context:
def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
"""
:param arguments: Arguments for the cross-attention control process
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
"""
self.arguments = arguments
self.step_count = step_count
@classmethod
def remove_cross_attention_control(cls, model):
cls.remove_attention_function(model)
@classmethod
def setup_cross_attention_control(cls, model,
cross_attention_control_args: Arguments
):
"""
Inject attention parameters and functions into the passed in model to enable cross attention editing.
:param model: The unet model to inject into.
:param cross_attention_control_args: Arguments passed to the CrossAttentionControl implementation
:return: None
"""
# adapted from init_attention_edit
device = cross_attention_control_args.edited_conditioning.device
# urgh. should this be hardcoded?
max_length = 77
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
mask = torch.zeros(max_length)
indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.zeros(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in cross_attention_control_args.edit_opcodes:
if b0 < max_length:
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
# these tokens have not been edited
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF):
m.last_attn_slice_mask = None
m.last_attn_slice_indices = None
for m in cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS):
m.last_attn_slice_mask = mask.to(device)
m.last_attn_slice_indices = indices.to(device)
cls.inject_attention_function(model)
class CrossAttentionType(Enum):
SELF = 1
TOKENS = 2
@classmethod
def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', percent_through:float=None)\
-> list['CrossAttentionControl.CrossAttentionType']:
"""
Should cross-attention control be applied on the given step?
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
"""
if percent_through is None:
return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS]
opts = context.arguments.edit_options
to_control = []
if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
to_control.append(cls.CrossAttentionType.SELF)
if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
to_control.append(cls.CrossAttentionType.TOKENS)
return to_control
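# (note) with the default options from prompt_parser (s_start=0.0, s_end≈0.206, t_start=0.0, t_end=1.0),
# SELF control is active only for roughly the first 20% of steps, while TOKENS control is active throughout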
@classmethod
def get_attention_modules(cls, model, which: CrossAttentionType):
which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2"
return [module for name, module in model.named_modules() if
type(module).__name__ == "CrossAttention" and which_attn in name]
@classmethod
def clear_requests(cls, model):
self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF)
tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS)
for m in self_attention_modules+tokens_attention_modules:
m.save_last_attn_slice = False
m.use_last_attn_slice = False
@classmethod
def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
modules = cls.get_attention_modules(model, cross_attention_type)
for m in modules:
# clear out the saved slice in case the outermost dim changes
m.last_attn_slice = None
m.save_last_attn_slice = True
@classmethod
def request_apply_saved_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
modules = cls.get_attention_modules(model, cross_attention_type)
for m in modules:
m.use_last_attn_slice = True
@classmethod
def inject_attention_function(cls, unet):
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
#print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim)
attn_slice = suggested_attention_slice
if dim is not None:
start = offset
end = start+slice_size
#print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
#else:
# print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
if self.use_last_attn_slice:
this_attn_slice = attn_slice
if self.last_attn_slice_mask is not None:
# indices and mask operate on dim=2, no need to slice
base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
base_attn_slice_mask = self.last_attn_slice_mask
if dim is None:
base_attn_slice = base_attn_slice_full
#print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
elif dim == 0:
base_attn_slice = base_attn_slice_full[start:end]
#print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
elif dim == 1:
base_attn_slice = base_attn_slice_full[:, start:end]
#print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \
base_attn_slice * base_attn_slice_mask
else:
if dim is None:
attn_slice = self.last_attn_slice
#print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
elif dim == 0:
attn_slice = self.last_attn_slice[start:end]
#print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
elif dim == 1:
attn_slice = self.last_attn_slice[:, start:end]
#print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
if self.save_last_attn_slice:
if dim is None:
self.last_attn_slice = attn_slice
elif dim == 0:
# dynamically grow last_attn_slice if needed
if self.last_attn_slice is None:
self.last_attn_slice = attn_slice
#print("no last_attn_slice: shape now", self.last_attn_slice.shape)
elif self.last_attn_slice.shape[0] == start:
self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0)
assert(self.last_attn_slice.shape[0] == end)
#print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
else:
# no need to grow
self.last_attn_slice[start:end] = attn_slice
#print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
elif dim == 1:
# dynamically grow last_attn_slice if needed
if self.last_attn_slice is None:
self.last_attn_slice = attn_slice
elif self.last_attn_slice.shape[1] == start:
self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1)
assert(self.last_attn_slice.shape[1] == end)
else:
# no need to grow
self.last_attn_slice[:, start:end] = attn_slice
if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
if dim is None:
weights = self.last_attn_slice_weights
elif dim == 0:
weights = self.last_attn_slice_weights[start:end]
elif dim == 1:
weights = self.last_attn_slice_weights[:, start:end]
attn_slice = attn_slice * weights
return attn_slice
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == "CrossAttention":
module.last_attn_slice = None
module.last_attn_slice_indices = None
module.last_attn_slice_mask = None
module.last_attn_slice_weights = None
module.use_last_attn_weights = False
module.use_last_attn_slice = False
module.save_last_attn_slice = False
module.set_attention_slice_wrangler(attention_slice_wrangler)
@classmethod
def remove_attention_function(cls, unet):
for name, module in unet.named_modules():
module_name = type(module).__name__
if module_name == "CrossAttention":
module.set_attention_slice_wrangler(None)
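The wrangler callback is the only integration point: `inject_attention_function` stores per-module state and registers the callback, and `remove_attention_function` unregisters it. A minimal sketch of attaching a pass-through wrangler to a unet, assuming only the `set_attention_slice_wrangler` hook added to `CrossAttention` below (the function names here are illustrative):

```python
# Sketch: attach a no-op wrangler to every CrossAttention module in a unet.
# Assumes the set_attention_slice_wrangler() hook added in
# ldm/modules/attention.py; everything else here is illustrative.
def passthrough_wrangler(module, attention_scores, suggested_attention_slice,
                         dim, offset, slice_size):
    # A real wrangler would save or substitute attention maps here,
    # as attention_slice_wrangler() above does.
    return suggested_attention_slice

def attach_passthrough(unet):
    for name, module in unet.named_modules():
        if type(module).__name__ == "CrossAttention":
            module.set_attention_slice_wrangler(passthrough_wrangler)
```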

View File

@ -1,10 +1,7 @@
"""SAMPLING ONLY."""
import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.models.diffusion.sampler import Sampler
from ldm.modules.diffusionmodules.util import noise_like
@ -12,6 +9,21 @@ class DDIMSampler(Sampler):
def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__(model,schedule,model.num_timesteps,device)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
else:
self.invokeai_diffuser.remove_cross_attention_control()
# This is the central routine
@torch.no_grad()
def p_sample(
@ -29,6 +41,7 @@ class DDIMSampler(Sampler):
corrector_kwargs=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
step_count:int=1000, # total number of steps
**kwargs,
):
b, *_, device = *x.shape, x.device
@ -37,15 +50,14 @@ class DDIMSampler(Sampler):
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
# damian0815 would like to know when/if this code path is used
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond
)
step_index = step_count-(index+1)
e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
unconditional_conditioning, c,
unconditional_guidance_scale,
step_index=step_index)
if score_corrector is not None:
assert self.model.parameterization == 'eps'

View File

@ -820,21 +820,21 @@ class LatentDiffusion(DDPM):
)
return self.scale_factor * z
def get_learned_conditioning(self, c):
def get_learned_conditioning(self, c, **kwargs):
if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable(
self.cond_stage_model.encode
):
c = self.cond_stage_model.encode(
c, embedding_manager=self.embedding_manager
c, embedding_manager=self.embedding_manager, **kwargs
)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
else:
c = self.cond_stage_model(c)
c = self.cond_stage_model(c, **kwargs)
else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c, **kwargs)
return c
def meshgrid(self, h, w):

View File

@ -1,16 +1,12 @@
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
import k_diffusion as K
import torch
import torch.nn as nn
from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.sampler import Sampler
from ldm.util import rand_perlin_2d
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like,
extract_into_tensor,
)
from torch import nn
from .sampler import Sampler
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
if threshold <= 0.0:
@ -33,12 +29,24 @@ class CFGDenoiser(nn.Module):
self.threshold = threshold
self.warmup_max = warmup
self.warmup = max(warmup / 10, 1)
self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
def prepare_to_sample(self, t_enc, **kwargs):
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc)
else:
self.invokeai_diffuser.remove_cross_attention_control()
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
# apply threshold
if self.warmup < self.warmup_max:
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
self.warmup += 1
@ -46,7 +54,8 @@ class CFGDenoiser(nn.Module):
thresh = self.threshold
if thresh > self.threshold:
thresh = self.threshold
return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh)
return cfg_apply_threshold(next_x, thresh)
class KSampler(Sampler):
@ -61,16 +70,6 @@ class KSampler(Sampler):
self.ds = None
self.s_in = None
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(
x_in, sigma_in, cond=cond_in
).chunk(2)
return uncond + (cond - uncond) * cond_scale
def make_schedule(
self,
ddim_num_steps,
@ -118,6 +117,7 @@ class KSampler(Sampler):
use_original_steps=False,
init_latent = None,
mask = None,
**kwargs
):
samples,_ = self.sample(
batch_size = 1,
@ -129,7 +129,8 @@ class KSampler(Sampler):
unconditional_conditioning = unconditional_conditioning,
img_callback = img_callback,
x0 = init_latent,
mask = mask
mask = mask,
**kwargs
)
return samples
@ -163,6 +164,7 @@ class KSampler(Sampler):
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
extra_conditioning_info=None,
threshold = 0,
perlin = 0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@ -181,7 +183,6 @@ class KSampler(Sampler):
)
# sigmas are set up in make_schedule - we take the last steps items
total_steps = len(self.sigmas)
sigmas = self.sigmas[-S-1:]
# x_T is variation noise. When an init image is provided (in x0) we need to add
@ -195,19 +196,21 @@ class KSampler(Sampler):
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
extra_args = {
'cond': conditioning,
'uncond': unconditional_conditioning,
'cond_scale': unconditional_guidance_scale,
}
print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)')
return (
sampling_result = (
K.sampling.__dict__[f'sample_{self.schedule}'](
model_wrap_cfg, x, sigmas, extra_args=extra_args,
callback=route_callback
),
None,
)
return sampling_result
# this code will support inpainting if and when ksampler API modified or
# a workaround is found.
@ -220,6 +223,7 @@ class KSampler(Sampler):
index,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
extra_conditioning_info=None,
**kwargs,
):
if self.model_wrap is None:
@ -245,6 +249,7 @@ class KSampler(Sampler):
# so the actual formula for indexing into sigmas:
# sigma_index = (steps-index)
s_index = t_enc - index - 1
self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info)
img = K.sampling.__dict__[f'_{self.schedule}'](
self.model_wrap,
img,
@ -269,7 +274,7 @@ class KSampler(Sampler):
else:
return x
def prepare_to_sample(self,t_enc):
def prepare_to_sample(self,t_enc,**kwargs):
self.t_enc = t_enc
self.model_wrap = None
self.ds = None

View File

@ -5,6 +5,7 @@ import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.invoke.devices import choose_torch_device
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ldm.models.diffusion.sampler import Sampler
from ldm.modules.diffusionmodules.util import noise_like
@ -13,6 +14,21 @@ class PLMSSampler(Sampler):
def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__(model,schedule,model.num_timesteps, device)
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
else:
self.invokeai_diffuser.remove_cross_attention_control()
# this is the essential routine
@torch.no_grad()
def p_sample(
@ -32,6 +48,7 @@ class PLMSSampler(Sampler):
unconditional_conditioning=None,
old_eps=[],
t_next=None,
step_count:int=1000, # total number of steps
**kwargs,
):
b, *_, device = *x.shape, x.device
@ -41,17 +58,15 @@ class PLMSSampler(Sampler):
unconditional_conditioning is None
or unconditional_guidance_scale == 1.0
):
# damian0815 would like to know when/if this code path is used
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(
x_in, t_in, c_in
).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond
)
# step_index counts in the opposite direction to index
step_index = step_count-(index+1)
e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
unconditional_conditioning, c,
unconditional_guidance_scale,
step_index=step_index)
if score_corrector is not None:
assert self.model.parameterization == 'eps'

View File

@ -4,6 +4,8 @@ ldm.models.diffusion.sampler
Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc
'''
from math import ceil
import torch
import numpy as np
from tqdm import tqdm
@ -190,6 +192,7 @@ class Sampler(object):
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
steps=S,
**kwargs
)
return samples, intermediates
@ -214,6 +217,7 @@ class Sampler(object):
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
steps=None,
**kwargs
):
b = shape[0]
time_range = (
@ -231,7 +235,7 @@ class Sampler(object):
dynamic_ncols=True,
)
old_eps = []
self.prepare_to_sample(t_enc=total_steps)
self.prepare_to_sample(t_enc=total_steps,all_timesteps_count=steps,**kwargs)
img = self.get_initial_image(x_T,shape,total_steps)
# probably don't need this at all
@ -274,6 +278,7 @@ class Sampler(object):
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps,
t_next=ts_next,
step_count=steps
)
img, pred_x0, e_t = outs
@ -305,6 +310,8 @@ class Sampler(object):
use_original_steps=False,
init_latent = None,
mask = None,
all_timesteps_count = None,
**kwargs
):
timesteps = (
@ -321,7 +328,7 @@ class Sampler(object):
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
x0 = init_latent
self.prepare_to_sample(t_enc=total_steps)
self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs)
for i, step in enumerate(iterator):
index = total_steps - i - 1
@ -353,6 +360,7 @@ class Sampler(object):
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
t_next = ts_next,
step_count=len(self.ddim_timesteps)
)
x_dec, pred_x0, e_t = outs
@ -411,3 +419,4 @@ class Sampler(object):
return self.model.inner_model.q_sample(x0,ts)
'''
return self.model.q_sample(x0,ts)

View File

@ -0,0 +1,176 @@
from math import ceil
from typing import Callable, Optional
import torch
from ldm.models.diffusion.cross_attention_control import CrossAttentionControl
class InvokeAIDiffuserComponent:
'''
The aim of this component is to provide a single place for code that can be applied identically to
all InvokeAI diffusion procedures.
At the moment it includes the following features:
* Cross Attention Control ("prompt2prompt")
'''
class ExtraConditioningInfo:
def __init__(self, cross_attention_control_args: Optional[CrossAttentionControl.Arguments]):
self.cross_attention_control_args = cross_attention_control_args
@property
def wants_cross_attention_control(self):
return self.cross_attention_control_args is not None
def __init__(self, model, model_forward_callback:
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
):
"""
:param model: the unet model to pass through to cross attention control
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply); it will be called repeatedly, and most likely should simply call model.forward(x, sigma, conditioning)
"""
self.model = model
self.model_forward_callback = model_forward_callback
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
self.conditioning = conditioning
self.cross_attention_control_context = CrossAttentionControl.Context(
arguments=self.conditioning.cross_attention_control_args,
step_count=step_count
)
CrossAttentionControl.setup_cross_attention_control(self.model,
cross_attention_control_args=self.conditioning.cross_attention_control_args
)
#todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct
#todo: apply edit_options using step_count
def remove_cross_attention_control(self):
self.conditioning = None
self.cross_attention_control_context = None
CrossAttentionControl.remove_cross_attention_control(self.model)
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
unconditioning: torch.Tensor, conditioning: torch.Tensor,
unconditional_guidance_scale: float,
step_index: int=None
):
"""
:param x: Current latents
:param sigma: aka t, passed to the internal model to control how much denoising will occur
:param unconditioning: [B x 77 x 768] embeddings for unconditioned output
:param conditioning: [B x 77 x 768] embeddings for conditioned output
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
:param step_index: Counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotonically increase.
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
"""
CrossAttentionControl.clear_requests(self.model)
cross_attention_control_types_to_do = []
if self.cross_attention_control_context is not None:
if step_index is not None:
# percent_through will never reach 1.0 (but this is intended)
percent_through = float(step_index) / float(self.cross_attention_control_context.step_count)
else:
# find the current sigma in the sigma sequence
# todo: this doesn't work with k_dpm_2 because the sigma used jumps around in the sequence
sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1]
# flip because sigmas[0] is for the fully denoised image
# percent_through must be <1
percent_through = 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0])
#print('estimated percent_through', percent_through, 'from sigma', sigma.item())
cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through)
if len(cross_attention_control_types_to_do)==0:
#print('not doing cross attention control')
# faster batched path
x_twice = torch.cat([x]*2)
sigma_twice = torch.cat([sigma]*2)
both_conditionings = torch.cat([unconditioning, conditioning])
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
else:
#print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
# slower non-batched path (20% slower on mac MPS)
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
# This messes up their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
# (For the batched invocation the `wrangler` function gets an attention tensor with shape[0]=16,
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
# process x using the original prompt, saving the attention maps
for type in cross_attention_control_types_to_do:
CrossAttentionControl.request_save_attention_maps(self.model, type)
_ = self.model_forward_callback(x, sigma, conditioning)
CrossAttentionControl.clear_requests(self.model)
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
for type in cross_attention_control_types_to_do:
CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
CrossAttentionControl.clear_requests(self.model)
# to scale how much effect conditioning has, calculate the changes it does and then scale that
scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale
combined_next_x = unconditioned_next_x + scaled_delta
return combined_next_x
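The final combination is standard classifier-free guidance: the unconditioned prediction plus a scaled delta towards the conditioned prediction. A quick numeric sketch (all values invented):

```python
import torch

# One-element "latents", purely to show the CFG arithmetic.
unconditioned_next_x = torch.tensor([0.2])
conditioned_next_x = torch.tensor([0.5])
unconditional_guidance_scale = 7.5

scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale
combined_next_x = unconditioned_next_x + scaled_delta
assert torch.allclose(combined_next_x, torch.tensor([2.45]))  # 0.2 + 7.5 * 0.3
```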
# todo: make this work
@classmethod
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2) # aka sigmas
deltas = None
uncond_latents = None
weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
# below is fugly omg
num_actual_conditionings = len(c_or_weighted_c_list)
conditionings = [uc] + [c for c,weight in weighted_cond_list]
weights = [1] + [weight for c,weight in weighted_cond_list]
chunk_count = ceil(len(conditionings)/2)
deltas = None
for chunk_index in range(chunk_count):
offset = chunk_index*2
chunk_size = min(2, len(conditionings)-offset)
if chunk_size == 1:
c_in = conditionings[offset]
latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
latents_b = None
else:
c_in = torch.cat(conditionings[offset:offset+2])
latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
# first chunk is guaranteed to be 2 entries: uncond_latents + first conditioning
if chunk_index == 0:
uncond_latents = latents_a
deltas = latents_b - uncond_latents
else:
deltas = torch.cat((deltas, latents_a - uncond_latents))
if latents_b is not None:
deltas = torch.cat((deltas, latents_b - uncond_latents))
# merge the weighted deltas together into a single merged delta
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
normalize = False
if normalize:
per_delta_weights /= torch.sum(per_delta_weights)
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
return uncond_latents + deltas_merged * global_guidance_scale
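The merge at the end of `apply_conjunction` is a weighted sum of per-conditioning deltas from the unconditioned latents; a sketch of just that step, with invented shapes and weights:

```python
import torch

# Two conditionings with weights 1.0 and 0.5; deltas is (num_conditionings, C, H, W).
deltas = torch.randn(2, 4, 8, 8)
weights = [1.0, 0.5]

per_delta_weights = torch.tensor(weights, dtype=deltas.dtype)
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
assert deltas_merged.shape == (1, 4, 8, 8)
```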

View File

@ -1,5 +1,7 @@
from inspect import isfunction
import math
from typing import Callable
import torch
import torch.nn.functional as F
from torch import nn, einsum
@ -150,6 +152,7 @@ class SpatialSelfAttention(nn.Module):
return x+h_
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
@ -170,46 +173,73 @@ class CrossAttention(nn.Module):
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
def einsum_op_compvis(self, q, k, v):
s = einsum('b i d, b j d -> b i j', q, k)
s = s.softmax(dim=-1, dtype=s.dtype)
return einsum('b i j, b j d -> b i d', s, v)
self.attention_slice_wrangler = None
def einsum_op_slice_0(self, q, k, v, slice_size):
def set_attention_slice_wrangler(self, wrangler:Callable[[nn.Module, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]):
'''
Set a custom attention calculator to be called when attention is calculated
:param wrangler: Callback, with args (self, attention_scores, suggested_attention_slice, dim, offset, slice_size),
which returns either the suggested_attention_slice or an adjusted equivalent.
self is the current CrossAttention module for which the callback is being invoked.
attention_scores are the raw attention scores (the q·k einsum).
suggested_attention_slice is a softmax(dim=-1) over attention_scores.
dim is None if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
If dim is not None, offset and slice_size specify the slice start and length.
Pass None as the wrangler to restore the default attention calculation.
'''
self.attention_slice_wrangler = wrangler
def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
# calculate attention scores
attention_scores = einsum('b i d, b j d -> b i j', q, k)
# calculate the attention slice by taking a softmax over the scores for each latent pixel
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
if self.attention_slice_wrangler is not None:
attention_slice = self.attention_slice_wrangler(self, attention_scores, default_attention_slice, dim, offset, slice_size)
else:
attention_slice = default_attention_slice
return einsum('b i j, b j d -> b i d', attention_slice, v)
def einsum_op_slice_dim0(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end])
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
return r
def einsum_op_slice_1(self, q, k, v, slice_size):
def einsum_op_slice_dim1(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v)
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
return r
def einsum_op_mps_v1(self, q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
return self.einsum_op_compvis(q, k, v)
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
return self.einsum_op_slice_1(q, k, v, slice_size)
return self.einsum_op_slice_dim1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v):
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
return self.einsum_op_compvis(q, k, v)
return self.einsum_lowest_level(q, k, v, None, None, None)
else:
return self.einsum_op_slice_0(q, k, v, 1)
return self.einsum_op_slice_dim0(q, k, v, 1)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
return self.einsum_op_compvis(q, k, v)
return self.einsum_lowest_level(q, k, v, None, None, None)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
return self.einsum_op_slice_0(q, k, v, q.shape[0] // div)
return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
print("warning: untested call to einsum_op_slice_dim0")
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
print("warning: untested call to einsum_op_slice_dim1")
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
def einsum_op_cuda(self, q, k, v):
stats = torch.cuda.memory_stats(q.device)
@ -221,7 +251,7 @@ class CrossAttention(nn.Module):
# Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def einsum_op(self, q, k, v):
def get_attention_mem_efficient(self, q, k, v):
if q.device.type == 'cuda':
return self.einsum_op_cuda(q, k, v)
@ -244,8 +274,13 @@ class CrossAttention(nn.Module):
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r = self.einsum_op(q, k, v)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
r = self.get_attention_mem_efficient(q, k, v)
hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
return self.to_out(hidden_states)
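Because the softmax runs over the key dimension (dim=-1), slicing the query rows changes nothing numerically; the sliced paths exist purely to bound peak memory. A sketch verifying dim-1 slicing against the whole-tensor path (shapes invented):

```python
import torch
from torch import einsum

def attn(q, k, v):
    s = einsum('b i d, b j d -> b i j', q, k).softmax(dim=-1)
    return einsum('b i j, b j d -> b i d', s, v)

q, k, v = (torch.randn(2, 64, 40) for _ in range(3))
whole = attn(q, k, v)

sliced = torch.zeros_like(whole)
slice_size = 16
for i in range(0, q.shape[1], slice_size):
    sliced[:, i:i + slice_size] = attn(q[:, i:i + slice_size], k, v)

assert torch.allclose(whole, sliced, atol=1e-6)
```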
class BasicTransformerBlock(nn.Module):

View File

@ -1,3 +1,5 @@
import math
import torch
import torch.nn as nn
from functools import partial
@ -437,6 +439,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
param.requires_grad = False
def forward(self, text, **kwargs):
batch_encoding = self.tokenizer(
text,
truncation=True,
@ -454,6 +457,222 @@ class FrozenCLIPEmbedder(AbstractEncoder):
def encode(self, text, **kwargs):
return self(text, **kwargs)
class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
fragment_weights_key = "fragment_weights"
return_tokens_key = "return_tokens"
def forward(self, text: list, **kwargs):
'''
:param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
weights shall be applied.
:param kwargs: If the keyword arg "fragment_weights" is passed, it shall contain a batch of lists of weights
for the prompt fragments. In this case text must contain batches of lists of prompt fragments.
:return: A tensor of shape (B, 77, 768) containing weighted embeddings
'''
if self.fragment_weights_key not in kwargs:
# fallback to base class implementation
return super().forward(text, **kwargs)
fragment_weights = kwargs[self.fragment_weights_key]
# self.transformer doesn't like receiving "fragment_weights" as an argument
kwargs.pop(self.fragment_weights_key)
should_return_tokens = False
if self.return_tokens_key in kwargs:
should_return_tokens = kwargs.get(self.return_tokens_key, False)
# self.transformer doesn't like having extra kwargs
kwargs.pop(self.return_tokens_key)
batch_z = None
batch_tokens = None
for fragments, weights in zip(text, fragment_weights):
# First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
# applying a multiplier to the CFG scale on a per-token basis).
# For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept
# captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active
# interest, however small, in redness; what the user probably intends when they attach the number 0.01 to
# "red" is to tell SD that it should almost completely *ignore* redness).
# To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt
# string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the
# closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment.
# handle weights >=1
tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights)
base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
# this is our starting point
embeddings = base_embedding.unsqueeze(0)
per_embedding_weights = [1.0]
# now handle weights <1
# Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped
# with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting
# embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words
# removed.
# eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding
# for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it
# such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
for index, fragment_weight in enumerate(weights):
if fragment_weight < 1:
fragments_without_this = fragments[:index] + fragments[index+1:]
weights_without_this = weights[:index] + weights[index+1:]
tokens, per_token_weights = self.get_tokens_and_weights(fragments_without_this, weights_without_this)
embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
# weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
# if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
# therefore:
# fragment_weight = 1: we are at base_z => lerp weight 0
# fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1
# fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf
# so let's use tan(), because:
# tan is 0.0 at 0,
# 1.0 at PI/4, and
# inf at PI/2
# -> tan((1-weight)*PI/2) should give us ideal lerp weights
epsilon = 1e-9
fragment_weight = max(epsilon, fragment_weight) # inf is bad
embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
# todo handle negative weight?
per_embedding_weights.append(embedding_lerp_weight)
lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
#print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
# append to batch
# concatenate along the batch dimension
batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=0)
batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=0)
# should have shape (B, 77, 768)
#print(f"assembled all tokens into tensor of shape {batch_z.shape}")
if should_return_tokens:
return batch_z, batch_tokens
else:
return batch_z
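The `tan()` mapping converts a fragment weight in (0, 1] into a lerp weight for the fragment-removed embedding; a quick numeric sketch of just that mapping:

```python
import math

def embedding_lerp_weight(fragment_weight: float, epsilon: float = 1e-9) -> float:
    # As in forward() above: weight 1.0 -> 0.0 (no pull towards "removed"),
    # 0.5 -> 1.0 (half-way), and -> infinity as the weight approaches 0.
    fragment_weight = max(epsilon, fragment_weight)
    return math.tan((1.0 - fragment_weight) * math.pi / 2)

assert embedding_lerp_weight(1.0) == 0.0
assert abs(embedding_lerp_weight(0.5) - 1.0) < 1e-9
assert embedding_lerp_weight(0.01) > 30  # almost entirely the "removed" embedding
```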
def get_tokens(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
tokens = self.tokenizer(
fragments,
truncation=True,
max_length=self.max_length,
return_overflowing_tokens=False,
padding='do_not_pad',
return_tensors=None, # just give me a list of ints
)['input_ids']
if include_start_and_end_markers:
return tokens
else:
return [x[1:-1] for x in tokens]
@classmethod
def apply_embedding_weights(cls, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
if normalize:
per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
#reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
return torch.sum(embeddings * reshaped_weights, dim=1)
# lerped embeddings has shape (77, 768)
def get_tokens_and_weights(self, fragments: list[str], weights: list[float]) -> (torch.Tensor, torch.Tensor):
'''
:param fragments:
:param weights: Per-fragment weights (CFG scaling). No need for these to be normalized. They will not be normalized here and that's fine.
:return:
'''
# empty is meaningful
if len(fragments) == 0 and len(weights) == 0:
fragments = ['']
weights = [1]
item_encodings = self.tokenizer(
fragments,
truncation=True,
max_length=self.max_length,
return_overflowing_tokens=False,
padding='do_not_pad',
return_tensors=None, # just give me a list of ints
)['input_ids']
all_tokens = []
per_token_weights = []
#print("all fragments:", fragments, weights)
for index, fragment in enumerate(item_encodings):
weight = weights[index]
#print("processing fragment", fragment, weight)
fragment_tokens = item_encodings[index]
#print("fragment", fragment, "processed to", fragment_tokens)
# trim bos and eos markers before appending
all_tokens.extend(fragment_tokens[1:-1])
per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
if (len(all_tokens) + 2) > self.max_length:
excess_token_count = (len(all_tokens) + 2) - self.max_length
print(f"prompt is {excess_token_count} token(s) too long and has been truncated")
all_tokens = all_tokens[:self.max_length - 2]
# pad out to a 77-entry array: [bos_token, <prompt tokens>, eos_token, ..., eos_token]
# (77 = self.max_length)
pad_length = self.max_length - 1 - len(all_tokens)
all_tokens.insert(0, self.tokenizer.bos_token_id)
all_tokens.extend([self.tokenizer.eos_token_id] * pad_length)
per_token_weights.insert(0, 1)
per_token_weights.extend([1] * pad_length)
all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
#print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
return all_tokens_tensor, per_token_weights_tensor
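The resulting layout is one bos token, the concatenated fragment tokens, and eos padding out to `max_length`; a sketch with invented token ids (the bos/eos values are those of the CLIP tokenizer):

```python
# Invented prompt tokens; 49406/49407 are the CLIP bos/eos ids.
bos, eos, max_length = 49406, 49407, 77
prompt_tokens = [320, 1125, 539]  # hypothetical ids for three words
weights = [1.0, 1.5, 1.0]

pad_length = max_length - 1 - len(prompt_tokens)
all_tokens = [bos] + prompt_tokens + [eos] * pad_length
per_token_weights = [1.0] + weights + [1.0] * pad_length
assert len(all_tokens) == len(per_token_weights) == max_length
```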
def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
'''
Build a tensor representing the passed-in tokens, each of which has a weight.
:param tokens: A tensor of shape (77) containing token ids (integers)
:param per_token_weights: A tensor of shape (77) containing weights (floats)
:param weight_delta_from_empty: Whether to weight just each token's distance from an "empty" feature vector, rather than multiplying its whole feature vector
:param kwargs: passed on to self.transformer()
:return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
'''
#print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs)
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
if weight_delta_from_empty:
empty_tokens = self.tokenizer([''] * z.shape[0],
truncation=True,
max_length=self.max_length,
padding='max_length',
return_tensors='pt'
)['input_ids'].to(self.device)
empty_z = self.transformer(input_ids=empty_tokens, **kwargs)
z_delta_from_empty = z - empty_z
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
weighted_z_delta_from_empty = (weighted_z-empty_z)
#print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
#print("using empty-delta method, first 5 rows:")
#print(weighted_z[:5])
return weighted_z
else:
original_mean = z.mean()
z *= batch_weights_expanded
after_weighting_mean = z.mean()
# correct the mean. not sure if this is right but it's what the automatic1111 fork of SD does
mean_correction_factor = original_mean/after_weighting_mean
z *= mean_correction_factor
return z
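The empty-delta method lerps each token's embedding between the embedding of the empty prompt and that of the full prompt; a self-contained sketch with invented tensors standing in for the transformer outputs:

```python
import torch

z = torch.randn(1, 77, 768)        # stand-in for the prompt embedding
empty_z = torch.randn(1, 77, 768)  # stand-in for the "" embedding
per_token_weights = torch.ones(77)
per_token_weights[5] = 0.0         # fully mute token 5

batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
weighted_z = empty_z + (z - empty_z) * batch_weights_expanded

# Weight 0 collapses a token's row onto the empty embedding; weight 1 keeps it.
assert torch.allclose(weighted_z[0, 5], empty_z[0, 5])
assert torch.allclose(weighted_z[0, 0], z[0, 0])
```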
class FrozenCLIPTextEmbedder(nn.Module):
"""

tests/test_prompt_parser.py
View File

@ -0,0 +1,401 @@
import unittest
import pyparsing
from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute, \
Fragment
def parse_prompt(prompt_string):
pp = PromptParser()
#print(f"parsing '{prompt_string}'")
parse_result = pp.parse_conjunction(prompt_string)
#print(f"-> parsed '{prompt_string}' to {parse_result}")
return parse_result
def make_basic_conjunction(strings: list[str]):
fragments = [Fragment(x) for x in strings]
return Conjunction([FlattenedPrompt(fragments)])
def make_weighted_conjunction(weighted_strings: list[tuple[str,float]]):
fragments = [Fragment(x, w) for x,w in weighted_strings]
return Conjunction([FlattenedPrompt(fragments)])
class PromptParserTestCase(unittest.TestCase):
def test_empty(self):
self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt(''))
def test_basic(self):
self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)"))
self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames"))
self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames"))
self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire"))
def test_attention(self):
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5"))
self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+"))
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-"))
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-"))
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("\"flames\"-"))
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5"))
self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++"))
self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--"))
self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++"))
self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1)]),
parse_prompt("(pretty flowers)+"))
self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1), (', the flames are too hot', 1)]),
parse_prompt("(pretty flowers)+, the flames are too hot"))
def test_no_parens_attention_runon(self):
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(1.1, 2))]), parse_prompt("fire flames++"))
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(0.9, 2))]), parse_prompt("fire flames--"))
self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("flowers fire++ flames"))
self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("flowers fire-- flames"))
def test_explicit_conjunction(self):
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and(1,1)'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()'))
self.assertEqual(
Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("(fire)2.0", "flames-").and()'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]),
FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()'))
def test_conjunction_weights(self):
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[2.0,1.0]), parse_prompt('("fire", "flames").and(2,1)'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[1.0,2.0]), parse_prompt('("fire", "flames").and(1,2)'))
# each malformed call needs its own assertRaises block, otherwise the
# first raise would skip the remaining calls
with self.assertRaises(PromptParser.ParsingException):
parse_prompt('("fire", "flames").and(2)')
with self.assertRaises(PromptParser.ParsingException):
parse_prompt('("fire", "flames").and(2,1,2)')
def test_complex_conjunction(self):
#print(parse_prompt("a person with a hat (riding a bicycle.swap(skateboard))++"))
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]),
parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle)++\").and(0.5, 0.5)"))
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]),
FlattenedPrompt([("a person with a hat", 1.0),
("riding a", 1.1*1.1),
CrossAttentionControlSubstitute(
[Fragment("bicycle", pow(1.1,2))],
[Fragment("skateboard", pow(1.1,2))])
])
], weights=[0.5, 0.5]),
parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)"))
def test_badly_formed(self):
def make_untouched_prompt(prompt):
return Conjunction([FlattenedPrompt([(prompt, 1.0)])])
def assert_if_prompt_string_not_untouched(prompt):
self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt))
assert_if_prompt_string_not_untouched('a test prompt')
# todo handle this
#assert_if_prompt_string_not_untouched('a badly formed +test prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('a badly (formed test prompt')
#with self.assertRaises(pyparsing.ParseException):
with self.assertRaises(pyparsing.ParseException):
parse_prompt('a badly (formed +test prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('a badly (formed +test )prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('a badly (formed +test )prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('(((a badly (formed +test )prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('(a (ba)dly (f)ormed +test prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('(a (ba)dly (f)ormed +test +prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('("((a badly (formed +test ").blend(1.0)')
def test_blend(self):
self.assertEqual(Conjunction(
[Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]),
parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)")
)
self.assertEqual(Conjunction([Blend(
[FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('hi', 1.0)])],
[0.7, 0.3, 1.0])]),
parse_prompt("(\"fire\", \"fire flames\", \"hi\").blend(0.7, 0.3, 1.0)")
)
self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]),
FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]),
FlattenedPrompt([('hi', 1.0)])],
weights=[0.7, 0.3, 1.0])]),
parse_prompt("(\"fire\", \"fire flames (hot)++\", \"hi\").blend(0.7, 0.3, 1.0)")
)
# blending a single entry is not a failure
self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]),
parse_prompt("(\"fire\").blend(0.7)")
)
# blend with empty
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
parse_prompt("(\"fire\", \"\").blend(0.7, 1)")
)
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
parse_prompt("(\"fire\", \" \").blend(0.7, 1)")
)
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
parse_prompt("(\"fire\", \" \").blend(0.7, 1)")
)
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([(',', 1.0)])], [0.7, 1.0])]),
parse_prompt("(\"fire\", \" , \").blend(0.7, 1)")
)
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('mountain, man, hairy', 1)]),
FlattenedPrompt([('face, teeth,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0])]),
parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)')
)
def test_nested(self):
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)]),
parse_prompt('fire (flames (trees)1.5)2.0'))
self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]),
FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])],
weights=[1.0, 1.0])]),
parse_prompt('("fire (flames)++", "mountain (man)2").blend(1,1)'))
def test_cross_attention_control(self):
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
Fragment('eating a hotdog', 1)])]), parse_prompt("a \"cat\".swap(dog) eating a hotdog"))
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
Fragment('eating a hotdog', 1)])]), parse_prompt("a cat.swap(dog) eating a hotdog"))
fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])])
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap("trees")'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap("trees")'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")'))
fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees and houses', 1)])])])
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")'))
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")'))
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")'))
trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])])
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap(flames)'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap(flames)'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)'))
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('flames',1)], [Fragment('trees',1)]),
(', fire', 1.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
parse_prompt('a forest landscape "".swap("in winter")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
parse_prompt('a forest landscape " ".swap("in winter")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
parse_prompt('a forest landscape "in winter".swap("")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
parse_prompt('a forest landscape "in winter".swap()'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
parse_prompt('a forest landscape "in winter".swap(" ")'))
def test_cross_attention_control_with_attention(self):
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('flames',0.5)], [Fragment('trees',0.7)]),
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"(flames)0.5".swap("(trees)0.7"), (fire)2.0'))
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7)]),
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7"), (fire)2.0'))
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7), Fragment('houses', 1)]),
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0'))
def test_cross_attention_control_options(self):
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start':0.1}),
Fragment('eating a hotdog', 1)])]),
parse_prompt("a \"cat\".swap(dog, s_start=0.1) eating a hotdog"))
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'t_start':0.1}),
Fragment('eating a hotdog', 1)])]),
parse_prompt("a \"cat\".swap(dog, t_start=0.1) eating a hotdog"))
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start': 20.0, 't_start':0.1}),
Fragment('eating a hotdog', 1)])]),
parse_prompt("a \"cat\".swap(dog, t_start=0.1, s_start=20) eating a hotdog"))
self.assertEqual(
Conjunction([
FlattenedPrompt([Fragment('a fantasy forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('', 1)], [Fragment('with a river', 1)],
options={'s_start': 0.8, 't_start': 0.8})])]),
parse_prompt("a fantasy forest landscape \"\".swap(with a river, s_start=0.8, t_start=0.8)"))
def test_escaping(self):
# make sure ", ( and ) can be escaped
self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain \(man\)'))
self.assertEqual(make_basic_conjunction(['mountain (man )']),parse_prompt('mountain (\(man)\)'))
self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain (\(man\))'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('mountain (\(man\))+'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('"mountain" (\(man\))+'))
self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('(man)', 1.1)]), parse_prompt('\\"mountain\\" (\(man\))+'))
# same weights for each are combined into one
self.assertEqual(make_weighted_conjunction([('"mountain" (man)', 1.1)]), parse_prompt('(\\"mountain\\")+ (\(man\))+'))
self.assertEqual(make_weighted_conjunction([('"mountain"', 1.1), ('(man)', 0.9)]), parse_prompt('(\\"mountain\\")+ (\(man\))-'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain (\(man\))1.1'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('"mountain" (\(man\))1.1'))
self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('\(man\)', 1.1)]),parse_prompt('\\"mountain\\" (\(man\))1.1'))
# same weights for each are combined into one
self.assertEqual(make_weighted_conjunction([('\\"mountain\\" \(man\)', 1.1)]),parse_prompt('(\\"mountain\\")+ (\(man\))1.1'))
self.assertEqual(make_weighted_conjunction([('\\"mountain\\"', 1.1), ('\(man\)', 0.9)]),parse_prompt('(\\"mountain\\")1.1 (\(man\))0.9'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy (mountain (\(man\))+)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy ((\(man\))1.1 "mountain")+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy ("mountain" (\(man\))1.1 )+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man', 1.1)]),parse_prompt('hairy ("mountain, man")+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy ("mountain, man" with a beard+)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, man" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\"man\\"" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, m\\"an\\"" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \(with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\(ith a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\( a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \)with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\) a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain, \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( ("mountain, \\\"man\" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \(with a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\(ith a (beard)2.0)+hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" with\( a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \)with a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' ("mountain, \\\"man\" with\) a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hai\(ry '))
self.assertEqual(make_weighted_conjunction([('mountain, \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('("mountain, \\\"man\" with a (beard)2.0)+ hairy\(\( '))
def test_cross_attention_escaping(self):
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
parse_prompt('mountain (man).swap(monkey)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]),
parse_prompt('mountain (man).swap(m\(onkey)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]),
parse_prompt('mountain (m\(an).swap(m\(onkey)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]),
parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
parse_prompt('mountain ("man").swap(monkey)'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
parse_prompt('mountain ("man").swap("monkey")'))
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('"man', 1)], [Fragment('monkey', 1)])])]),
parse_prompt('mountain (\\"man).swap("monkey")'))
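
# Legacy ':<weight>' blend syntax: each sub-prompt may carry a weight after
# ':'; a missing or empty weight defaults to 1, and attention modifiers
# (+/-) still apply inside each blended prompt.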
def test_legacy_blend(self):
pp = PromptParser()
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
FlattenedPrompt([('man mountain', 1)])],
weights=[0.5,0.5]),
pp.parse_legacy_blend('mountain man:1 man mountain:1'))
self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
weights=[0.5,0.5]),
pp.parse_legacy_blend('mountain+ man:1 man mountain-:1'))
self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
weights=[0.5,0.5]),
pp.parse_legacy_blend('mountain+ man:1 man mountain-'))
self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
weights=[0.5,0.5]),
pp.parse_legacy_blend('mountain+ man: man mountain-:'))
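# Explicit weights are normalized against their sum: ':3' and ':1' become
# 0.75 and 0.25, and a ':0' weight is kept as 0.0 rather than dropped.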
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
FlattenedPrompt([('man mountain', 1)])],
weights=[0.75,0.25]),
pp.parse_legacy_blend('mountain man:3 man mountain:1'))
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
FlattenedPrompt([('man mountain', 1)])],
weights=[1.0,0.0]),
pp.parse_legacy_blend('mountain man:3 man mountain:0'))
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
FlattenedPrompt([('man mountain', 1)])],
weights=[0.8,0.2]),
pp.parse_legacy_blend('"mountain man":4 man mountain'))

def test_single(self):
# TODO: decide how a badly formed prompt with a stray '+' should behave:
# either parse it gracefully or raise a clear error. Until then this case
# is not asserted.
#self.assertEqual(make_basic_conjunction(['a badly formed +test prompt']),
#                 parse_prompt('a badly formed +test prompt'))
pass

if __name__ == '__main__':
unittest.main()