add prompt language support for cross-attention .swap

This commit is contained in:
Damian at mba 2022-10-20 01:42:04 +02:00
parent 1ffd4a9e06
commit 42883545f9
6 changed files with 585 additions and 50 deletions

View File

@ -35,7 +35,7 @@ from ldm.invoke.devices import choose_torch_device, choose_precision
from ldm.invoke.conditioning import get_uc_and_c_and_ec from ldm.invoke.conditioning import get_uc_and_c_and_ec
from ldm.invoke.model_cache import ModelCache from ldm.invoke.model_cache import ModelCache
from ldm.invoke.seamless import configure_model_padding from ldm.invoke.seamless import configure_model_padding
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale #from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
def fix_func(orig): def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():

View File

@ -11,71 +11,93 @@ log_tokenization() print out colour-coded tokens and warn if trunca
''' '''
import re import re
from difflib import SequenceMatcher from difflib import SequenceMatcher
from typing import Union
import torch import torch
def get_uc_and_c_and_ec(prompt, model, log_tokens=False, skip_normalize=False): from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):
# Extract Unconditioned Words From Prompt # Extract Unconditioned Words From Prompt
unconditioned_words = '' unconditioned_words = ''
unconditional_regex = r'\[(.*?)\]' unconditional_regex = r'\[(.*?)\]'
unconditionals = re.findall(unconditional_regex, prompt) unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)
if len(unconditionals) > 0: if len(unconditionals) > 0:
unconditioned_words = ' '.join(unconditionals) unconditioned_words = ' '.join(unconditionals)
# Remove Unconditioned Words From Prompt # Remove Unconditioned Words From Prompt
unconditional_regex_compile = re.compile(unconditional_regex) unconditional_regex_compile = re.compile(unconditional_regex)
clean_prompt = unconditional_regex_compile.sub(' ', prompt) clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned)
prompt = re.sub(' +', ' ', clean_prompt) prompt_string_cleaned = re.sub(' +', ' ', clean_prompt)
else:
prompt_string_cleaned = prompt_string_uncleaned
edited_words = None pp = PromptParser()
edited_regex = r'\{(.*?)\}'
edited = re.findall(edited_regex, prompt)
if len(edited) > 0:
edited_words = ' '.join(edited)
edited_regex_compile = re.compile(edited_regex)
clean_prompt = edited_regex_compile.sub(' ', prompt)
prompt = re.sub(' +', ' ', clean_prompt)
# get weighted sub-prompts parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned)
weighted_subprompts = split_weighted_subprompts( parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words)
prompt, skip_normalize
)
ec = None conditioning = None
edited_conditioning = None
edit_opcodes = None edit_opcodes = None
uc, _ = model.get_learned_conditioning([unconditioned_words]) if parsed_prompt is Blend:
blend: Blend = parsed_prompt
embeddings_to_blend = None
for flattened_prompt in blend.prompts:
this_embedding = make_embeddings_for_flattened_prompt(model, flattened_prompt)
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
(embeddings_to_blend, this_embedding))
conditioning, _ = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
blend.weights,
normalize=blend.normalize_weights)
else:
flattened_prompt: FlattenedPrompt = parsed_prompt
wants_cross_attention_control = any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
if wants_cross_attention_control:
original_prompt = FlattenedPrompt()
edited_prompt = FlattenedPrompt()
for fragment in flattened_prompt.children:
if type(fragment) is CrossAttentionControlSubstitute:
original_prompt.append(fragment.original_fragment)
edited_prompt.append(fragment.edited_fragment)
elif type(fragment) is CrossAttentionControlAppend:
edited_prompt.append(fragment.fragment)
else:
# regular fragment
original_prompt.append(fragment)
edited_prompt.append(fragment)
original_embeddings, original_tokens = make_embeddings_for_flattened_prompt(model, original_prompt)
edited_embeddings, edited_tokens = make_embeddings_for_flattened_prompt(model, edited_prompt)
if len(weighted_subprompts) > 1: conditioning = original_embeddings
# i dont know if this is correct.. but it works edited_conditioning = edited_embeddings
c = torch.zeros_like(uc) edit_opcodes = build_token_edit_opcodes(original_tokens, edited_tokens)
# normalize each "sub prompt" and add it else:
for subprompt, weight in weighted_subprompts: conditioning, _ = make_embeddings_for_flattened_prompt(model, flattened_prompt)
log_tokenization(subprompt, model, log_tokens, weight)
subprompt_embeddings, _ = model.get_learned_conditioning([subprompt])
c = torch.add(
c,
subprompt_embeddings,
alpha=weight,
)
if edited_words is not None:
print("can't do cross-attention control with blends just yet, ignoring edits")
else: # just standard 1 prompt
log_tokenization(prompt, model, log_tokens, 1)
c, c_tokens = model.get_learned_conditioning([prompt])
if edited_words is not None:
ec, ec_tokens = model.get_learned_conditioning([edited_words])
edit_opcodes = build_token_edit_opcodes(c_tokens, ec_tokens)
return (uc, c, ec, edit_opcodes) unconditioning = make_embeddings_for_flattened_prompt(parsed_negative_prompt)
return (unconditioning, conditioning, edited_conditioning, edit_opcodes)
def build_token_edit_opcodes(c_tokens, ec_tokens):
tokens = c_tokens.cpu().numpy()[0]
tokens_edit = ec_tokens.cpu().numpy()[0]
opcodes = SequenceMatcher(None, tokens, tokens_edit).get_opcodes() def build_token_edit_opcodes(original_tokens, edited_tokens):
return opcodes original_tokens = original_tokens.cpu().numpy()[0]
edited_tokens = edited_tokens.cpu().numpy()[0]
return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
def make_embeddings_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
if type(flattened_prompt) is not FlattenedPrompt:
raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead"
fragments = [x[0] for x in flattened_prompt.children]
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True)
return embeddings, tokens
def split_weighted_subprompts(text, skip_normalize=False)->list: def split_weighted_subprompts(text, skip_normalize=False)->list:

326
ldm/invoke/prompt_parser.py Normal file
View File

@ -0,0 +1,326 @@
import pyparsing
import pyparsing as pp
from pyparsing import original_text_for
class Prompt():
def __init__(self, parts: list):
for c in parts:
if not issubclass(type(c), BaseFragment):
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {BaseFragment.__subclasses__()} are allowed")
self.children = parts
def __repr__(self):
return f"Prompt:{self.children}"
def __eq__(self, other):
return type(other) is Prompt and other.children == self.children
class FlattenedPrompt():
def __init__(self, parts: list):
# verify type correctness
parts_converted = []
for part in parts:
if issubclass(type(part), BaseFragment):
parts_converted.append(part)
elif type(part) is tuple:
# upgrade tuples to Fragments
if type(part[0]) is not str or (type(part[1]) is not float and type(part[1]) is not int):
raise PromptParser.ParsingException(
f"FlattenedPrompt cannot contain {part}, only Fragments or (str, float) tuples are allowed")
parts_converted.append(Fragment(part[0], part[1]))
else:
raise PromptParser.ParsingException(
f"FlattenedPrompt cannot contain {part}, only Fragments or (str, float) tuples are allowed")
# all looks good
self.children = parts_converted
def __repr__(self):
return f"FlattenedPrompt:{self.children}"
def __eq__(self, other):
return type(other) is FlattenedPrompt and other.children == self.children
# abstract base class for Fragments
class BaseFragment:
pass
class Fragment(BaseFragment):
def __init__(self, text: str, weight: float=1):
assert(type(text) is str)
self.text = text
self.weight = float(weight)
def __repr__(self):
return "Fragment:'"+self.text+"'@"+str(self.weight)
def __eq__(self, other):
return type(other) is Fragment \
and other.text == self.text \
and other.weight == self.weight
class CrossAttentionControlledFragment(BaseFragment):
pass
class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
def __init__(self, original: Fragment, edited: Fragment):
self.original = original
self.edited = edited
def __repr__(self):
return f"CrossAttentionControlSubstitute:('{self.original}'->'{self.edited}')"
def __eq__(self, other):
return type(other) is CrossAttentionControlSubstitute \
and other.original == self.original \
and other.edited == self.edited
class CrossAttentionControlAppend(CrossAttentionControlledFragment):
def __init__(self, fragment: Fragment):
self.fragment = fragment
def __repr__(self):
return "CrossAttentionControlAppend:",self.fragment
def __eq__(self, other):
return type(other) is CrossAttentionControlAppend \
and other.fragment == self.fragment
class Conjunction():
def __init__(self, prompts: list, weights: list = None):
# force everything to be a Prompt
#print("making conjunction with", parts)
self.prompts = [x if (type(x) is Prompt
or type(x) is Blend
or type(x) is FlattenedPrompt)
else Prompt(x) for x in prompts]
self.weights = [1.0]*len(self.prompts) if weights is None else list(weights)
if len(self.weights) != len(self.prompts):
raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
self.type = 'AND'
def __repr__(self):
return f"Conjunction:{self.prompts} | weights {self.weights}"
def __eq__(self, other):
return type(other) is Conjunction \
and other.prompts == self.prompts \
and other.weights == self.weights
class Blend():
def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
#print("making Blend with prompts", prompts, "and weights", weights)
if len(prompts) != len(weights):
raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
for c in prompts:
if type(c) is not Prompt and type(c) is not FlattenedPrompt:
raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
# upcast all lists to Prompt objects
self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
else Prompt(x) for x in prompts]
self.prompts = prompts
self.weights = weights
self.normalize_weights = normalize_weights
def __repr__(self):
return f"Blend:{self.prompts} | weights {self.weights}"
def __eq__(self, other):
return other.__repr__() == self.__repr__()
class PromptParser():
class ParsingException(Exception):
pass
def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
self.attention_plus_base = attention_plus_base
self.attention_minus_base = attention_minus_base
self.root = self.build_parser_logic()
def parse(self, prompt: str) -> [list]:
'''
:param prompt: The prompt string to parse
:return: a tuple
'''
#print(f"!!parsing '{prompt}'")
if len(prompt.strip()) == 0:
return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0])
root = self.root.parse_string(prompt)
#print(f"'{prompt}' parsed to root", root)
#fused = fuse_fragments(parts)
#print("fused to", fused)
return self.flatten(root[0])
def flatten(self, root: Conjunction):
def fuse_fragments(items):
# print("fusing fragments in ", items)
result = []
for x in items:
if issubclass(type(x), CrossAttentionControlledFragment):
result.append(x)
else:
last_weight = result[-1].weight \
if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \
else None
this_text = x.text
this_weight = x.weight
if last_weight is not None and last_weight == this_weight:
last_text = result[-1].text
result[-1] = Fragment(last_text + ' ' + this_text, last_weight)
else:
result.append(x)
return result
def flatten_internal(node, weight_scale, results, prefix):
#print(prefix + "flattening", node, "...")
if type(node) is pp.ParseResults:
for x in node:
results = flatten_internal(x, weight_scale, results, prefix+'pr')
#print(prefix, " ParseResults expanded, results is now", results)
elif issubclass(type(node), BaseFragment):
results.append(node)
#elif type(node) is Attention:
# #if node.weight < 1:
# # todo: inject a blend when flattening attention with weight <1"
# for c in node.children:
# results = flatten_internal(c, weight_scale*node.weight, results, prefix+' ')
elif type(node) is Blend:
flattened_subprompts = []
#print(" flattening blend with prompts", node.prompts, "weights", node.weights)
for prompt in node.prompts:
# prompt is a list
flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ')
results += [Blend(prompts=flattened_subprompts, weights=node.weights)]
elif type(node) is Prompt:
#print(prefix + "about to flatten Prompt with children", node.children)
flattened_prompt = []
for child in node.children:
flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ')
results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))]
#print(prefix + "after flattening Prompt, results is", results)
else:
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
#print(prefix + "-> after flattening", type(node), "results is", results)
return results
#print("flattening", root)
flattened_parts = []
for part in root.prompts:
flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
weights = root.weights
return Conjunction(flattened_parts, weights)
def build_parser_logic(self):
lparen = pp.Literal("(").suppress()
rparen = pp.Literal(")").suppress()
# accepts int or float notation, always maps to float
number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))
SPACE_CHARS = ' \t\n'
prompt_part = pp.Forward()
word = pp.Word(pp.printables).set_parse_action(lambda x: Fragment(' '.join([s for s in x])))
word.set_name("word")
word.set_debug(False)
def make_fragment(x):
#print("### making fragment for", x)
if type(x) is str:
return Fragment(x)
elif type(x) is pp.ParseResults or type(x) is list:
return Fragment(' '.join([s for s in x]))
else:
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
original_words = (
(lparen + pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) |
(pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('term2').set_debug(False) |
(lparen + pp.CharsNotIn(')') + rparen).set_name('term3').set_debug(False)
).set_name('original_words')
edited_words = (
(pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('termA').set_debug(False) |
pp.CharsNotIn(')').set_name('termB').set_debug(False)
).set_name('edited_words')
cross_attention_substitute = original_words + \
pp.Literal(".swap").suppress() + \
lparen + edited_words + rparen
cross_attention_substitute.set_name('cross_attention_substitute')
def make_cross_attention_substitute(x):
#print("making cacs for", x)
return CrossAttentionControlSubstitute(x[0], x[1])
#print("made", cacs)
#return cacs
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
# simple fragments of text
prompt_part << (cross_attention_substitute
#| attention
| word
)
prompt_part.set_debug(False)
prompt_part.set_name("prompt_part")
# root prompt definition
prompt = pp.Group(pp.OneOrMore(prompt_part))\
.set_parse_action(lambda x: Prompt(x[0]))
# weighted blend of prompts
# ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or
# int weights.
# can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c)
def make_prompt_from_quoted_string(x):
#print(' got quoted prompt', x)
x_unquoted = x[0][1:-1]
if len(x_unquoted.strip()) == 0:
# print(' b : just an empty string')
return Prompt([Fragment('')])
# print(' b parsing ', c_unquoted)
x_parsed = prompt.parse_string(x_unquoted)
#print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed)
return x_parsed[0]
quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
quoted_prompt.set_name('quoted_prompt')
blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms')
blend_weights = pp.delimited_list(number).set_name('blend_weights')
blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
+ pp.Literal(".blend").suppress()
+ lparen + pp.Group(blend_weights) + rparen).set_name('blend')
blend.set_debug(False)
blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1]))
conjunction_terms = blend_terms.copy().set_name('conjunction_terms')
conjunction_weights = blend_weights.copy().set_name('conjunction_weights')
conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen
+ pp.Literal(".and").suppress()
+ lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction')
def make_conjunction(x):
parts_raw = x[0][0]
weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw)
parts = [part for part in parts_raw]
return Conjunction(parts, weights)
conjunction_with_parens_and_quotes.set_parse_action(make_conjunction)
implicit_conjunction = pp.OneOrMore(blend | prompt)
implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
conjunction = conjunction_with_parens_and_quotes | implicit_conjunction
conjunction.set_debug(False)
# top-level is a conjunction of one or more blends or prompts
return conjunction

View File

@ -820,21 +820,21 @@ class LatentDiffusion(DDPM):
) )
return self.scale_factor * z return self.scale_factor * z
def get_learned_conditioning(self, c): def get_learned_conditioning(self, c, **kwargs):
if self.cond_stage_forward is None: if self.cond_stage_forward is None:
if hasattr(self.cond_stage_model, 'encode') and callable( if hasattr(self.cond_stage_model, 'encode') and callable(
self.cond_stage_model.encode self.cond_stage_model.encode
): ):
c = self.cond_stage_model.encode( c = self.cond_stage_model.encode(
c, embedding_manager=self.embedding_manager c, embedding_manager=self.embedding_manager, **kwargs
) )
if isinstance(c, DiagonalGaussianDistribution): if isinstance(c, DiagonalGaussianDistribution):
c = c.mode() c = c.mode()
else: else:
c = self.cond_stage_model(c) c = self.cond_stage_model(c, **kwargs)
else: else:
assert hasattr(self.cond_stage_model, self.cond_stage_forward) assert hasattr(self.cond_stage_model, self.cond_stage_forward)
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) c = getattr(self.cond_stage_model, self.cond_stage_forward)(c, **kwargs)
return c return c
def meshgrid(self, h, w): def meshgrid(self, h, w):

View File

@ -1,3 +1,5 @@
import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from functools import partial from functools import partial
@ -449,11 +451,23 @@ class FrozenCLIPEmbedder(AbstractEncoder):
tokens = batch_encoding['input_ids'].to(self.device) tokens = batch_encoding['input_ids'].to(self.device)
z = self.transformer(input_ids=tokens, **kwargs) z = self.transformer(input_ids=tokens, **kwargs)
return z, tokens if kwargs.get('return_tokens', False):
return z, tokens
else:
return z
def encode(self, text, **kwargs): def encode(self, text, **kwargs):
return self(text, **kwargs) return self(text, **kwargs)
class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
@classmethod
def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
if normalize:
per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
return torch.sum(embeddings * reshaped_weights, dim=1)
class FrozenCLIPTextEmbedder(nn.Module): class FrozenCLIPTextEmbedder(nn.Module):
""" """

173
tests/test_prompt_parser.py Normal file
View File

@ -0,0 +1,173 @@
import unittest
from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute
def parse_prompt(prompt_string):
pp = PromptParser()
#print(f"parsing '{prompt_string}'")
parse_result = pp.parse(prompt_string)
#print(f"-> parsed '{prompt_string}' to {parse_result}")
return parse_result
class PromptParserTestCase(unittest.TestCase):
def test_empty(self):
self.assertEqual(Conjunction([FlattenedPrompt([('', 1)])]), parse_prompt(''))
def test_basic(self):
self.assertEqual(Conjunction([FlattenedPrompt([('fire (flames)', 1)])]), parse_prompt("fire (flames)"))
self.assertEqual(Conjunction([FlattenedPrompt([("fire flames", 1)])]), parse_prompt("fire flames"))
self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames", 1)])]), parse_prompt("fire, flames"))
self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames , fire", 1)])]), parse_prompt("fire, flames , fire"))
def test_attention(self):
self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.5)])]), parse_prompt("0.5(flames)"))
self.assertEqual(Conjunction([FlattenedPrompt([('fire flames', 0.5)])]), parse_prompt("0.5(fire flames)"))
self.assertEqual(Conjunction([FlattenedPrompt([('flames', 1.1)])]), parse_prompt("+(flames)"))
self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.9)])]), parse_prompt("-(flames)"))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1), ('flames', 0.5)])]), parse_prompt("fire 0.5(flames)"))
self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(1.1, 2))])]), parse_prompt("++(flames)"))
self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(0.9, 2))])]), parse_prompt("--(flames)"))
self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))])]), parse_prompt("---(flowers) +++flames"))
self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))])]), parse_prompt("---(flowers) +++flames"))
self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames+', pow(1.1, 3))])]),
parse_prompt("---(flowers) +++flames+"))
self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1)])]),
parse_prompt("+(pretty flowers)"))
self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1), (', the flames are too hot', 1)])]),
parse_prompt("+(pretty flowers), the flames are too hot"))
def test_no_parens_attention_runon(self):
self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("++fire flames"))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("--fire flames"))
self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("flowers ++fire flames"))
self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("flowers --fire flames"))
def test_explicit_conjunction(self):
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and(1,1)'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()'))
self.assertEqual(
Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("2.0(fire)", "-flames").and()'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]),
FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()'))
def test_conjunction_weights(self):
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[2.0,1.0]), parse_prompt('("fire", "flames").and(2,1)'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[1.0,2.0]), parse_prompt('("fire", "flames").and(1,2)'))
with self.assertRaises(PromptParser.ParsingException):
parse_prompt('("fire", "flames").and(2)')
parse_prompt('("fire", "flames").and(2,1,2)')
def test_complex_conjunction(self):
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]),
parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle)\").and(0.5, 0.5)"))
def test_badly_formed(self):
def make_untouched_prompt(prompt):
return Conjunction([FlattenedPrompt([(prompt, 1.0)])])
def assert_if_prompt_string_not_untouched(prompt):
self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt))
assert_if_prompt_string_not_untouched('a test prompt')
assert_if_prompt_string_not_untouched('a badly (formed test prompt')
assert_if_prompt_string_not_untouched('a badly formed test+ prompt')
assert_if_prompt_string_not_untouched('a badly (formed test+ prompt')
assert_if_prompt_string_not_untouched('a badly (formed test+ )prompt')
assert_if_prompt_string_not_untouched('a badly (formed test+ )prompt')
assert_if_prompt_string_not_untouched('(((a badly (formed test+ )prompt')
assert_if_prompt_string_not_untouched('(a (ba)dly (f)ormed test+ prompt')
self.assertEqual(Conjunction([FlattenedPrompt([('(a (ba)dly (f)ormed test+', 1.0), ('prompt', 1.1)])]),
parse_prompt('(a (ba)dly (f)ormed test+ +prompt'))
self.assertEqual(Conjunction([Blend([FlattenedPrompt([('((a badly (formed test+', 1.0)])], weights=[1.0])]),
parse_prompt('("((a badly (formed test+ ").blend(1.0)'))
def test_blend(self):
self.assertEqual(Conjunction(
[Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]),
parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)")
)
self.assertEqual(Conjunction([Blend(
[FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('hi', 1.0)])],
[0.7, 0.3, 1.0])]),
parse_prompt("(\"fire\", \"fire flames\", \"hi\").blend(0.7, 0.3, 1.0)")
)
self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]),
FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]),
FlattenedPrompt([('hi', 1.0)])],
weights=[0.7, 0.3, 1.0])]),
parse_prompt("(\"fire\", \"fire flames ++(hot)\", \"hi\").blend(0.7, 0.3, 1.0)")
)
# blend a single entry is not a failure
self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]),
parse_prompt("(\"fire\").blend(0.7)")
)
# blend with empty
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
parse_prompt("(\"fire\", \"\").blend(0.7, 1)")
)
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
parse_prompt("(\"fire\", \" \").blend(0.7, 1)")
)
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
parse_prompt("(\"fire\", \" \").blend(0.7, 1)")
)
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([(',', 1.0)])], [0.7, 1.0])]),
parse_prompt("(\"fire\", \" , \").blend(0.7, 1)")
)
def test_nested(self):
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)])]),
parse_prompt('fire 2.0(flames 1.5(trees))'))
self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]),
FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])],
weights=[1.0, 1.0])]),
parse_prompt('("fire ++(flames)", "mountain 2(man)").blend(1,1)'))
def test_cross_attention_control(self):
fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute('flames', 'trees')])])
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap("trees")'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap("trees")'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")'))
fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute('flames', 'trees and houses')])])
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")'))
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")'))
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")'))
trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute('trees and houses', 'flames')])])
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap(flames)'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap(flames)'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)'))
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute('flames', 'trees'),
(', fire', 1.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire'))
if __name__ == '__main__':
unittest.main()