From 3f13dd3ae8bbdf06067adf891e5caa925fde394a Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Thu, 20 Oct 2022 21:05:36 +0200 Subject: [PATCH] prompt parsing is now much more robust --- ldm/invoke/prompt_parser.py | 151 ++++++++++++++++++------------- tests/test_prompt_parser.py | 173 +++++++++++++++++++++++++++--------- 2 files changed, 220 insertions(+), 104 deletions(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 398a596c7e..d576d069aa 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -9,8 +9,8 @@ class Prompt(): def __init__(self, parts: list): for c in parts: - if type(c) is not Attention and not issubclass(type(c), BaseFragment): - raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {BaseFragment.__subclasses__()} are allowed") + if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults: + raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} {c}, only {BaseFragment.__subclasses__()} are allowed") self.children = parts def __repr__(self): return f"Prompt:{self.children}" @@ -48,6 +48,9 @@ class BaseFragment: class Fragment(BaseFragment): def __init__(self, text: str, weight: float=1): assert(type(text) is str) + if '\\"' in text or '\\(' in text or '\\)' in text: + #print("Fragment converting escaped \( \) \\\" into ( ) \"") + text = text.replace('\\(', '(').replace('\\)', ')').replace('\\"', '"') self.text = text self.weight = float(weight) @@ -152,8 +155,10 @@ class PromptParser(): def parse(self, prompt: str) -> Conjunction: ''' + This parser is *very* forgiving. If it cannot parse syntax, it will return strings as-is to be passed on to the + diffusion. :param prompt: The prompt string to parse - :return: a tuple + :return: a Conjunction representing the parsed results. ''' #print(f"!!parsing '{prompt}'") @@ -169,7 +174,7 @@ class PromptParser(): def flatten(self, root: Conjunction): - print("flattening", root) + #print("flattening", root) def fuse_fragments(items): # print("fusing fragments in ", items) @@ -196,13 +201,13 @@ class PromptParser(): #print(prefix + "flattening", node, "...") if type(node) is pp.ParseResults: for x in node: - results = flatten_internal(x, weight_scale, results, prefix+'pr') + results = flatten_internal(x, weight_scale, results, prefix+' pr ') #print(prefix, " ParseResults expanded, results is now", results) elif type(node) is Attention: # if node.weight < 1: # todo: inject a blend when flattening attention with weight <1" - for c in node.children: - results = flatten_internal(c, weight_scale * node.weight, results, prefix + ' ') + for index,c in enumerate(node.children): + results = flatten_internal(c, weight_scale * node.weight, results, prefix + f" att{index} ") elif type(node) is Fragment: results += [Fragment(node.text, node.weight*weight_scale)] elif type(node) is CrossAttentionControlSubstitute: @@ -225,7 +230,7 @@ class PromptParser(): #print(prefix + "after flattening Prompt, results is", results) else: raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}") - print(prefix + "-> after flattening", type(node).__name__, "results is", results) + #print(prefix + "-> after flattening", type(node).__name__, "results is", results) return results @@ -246,6 +251,7 @@ class PromptParser(): # accepts int or float notation, always maps to float number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float)) SPACE_CHARS = string.whitespace + greedy_word = pp.Word(pp.printables, exclude_chars=string.whitespace).set_name('greedy_word') attention = pp.Forward() @@ -254,7 +260,7 @@ class PromptParser(): if type(x) is str: return Fragment(x) elif type(x) is pp.ParseResults or type(x) is list: - #print(f'converting {x} to Fragment') + #print(f'converting {type(x).__name__} to Fragment') return Fragment(' '.join([s for s in x])) else: raise PromptParser.ParsingException("Cannot make fragment from " + str(x)) @@ -264,52 +270,72 @@ class PromptParser(): parenthesized_fragment = pp.Forward() def parse_fragment_str(x): - print("parsing", x) + #print("parsing fragment string", x) if len(x[0].strip()) == 0: return Fragment('') - fragment_parser = pp.Group(pp.OneOrMore(attention | pp.Word(pp.printables, exclude_chars=string.whitespace).set_parse_action(make_fragment))) + fragment_parser = pp.Group(pp.OneOrMore(attention | (greedy_word.set_parse_action(make_fragment)))) fragment_parser.set_name('word_or_attention') result = fragment_parser.parse_string(x[0]) #result = (pp.OneOrMore(attention | unquoted_fragment) + pp.StringEnd()).parse_string(x[0]) - print("parsed to", result) + #print("parsed to", result) return result quoted_fragment << pp.QuotedString(quote_char='"', esc_char='\\') - quoted_fragment.set_parse_action(make_fragment).set_name('quoted_fragment') + quoted_fragment.set_parse_action(parse_fragment_str).set_name('quoted_fragment') + + self_unescaping_escaped_quote = pp.Literal('\\"').set_parse_action(lambda x: '"') + self_unescaping_escaped_lparen = pp.Literal('\\(').set_parse_action(lambda x: '(') + self_unescaping_escaped_rparen = pp.Literal('\\)').set_parse_action(lambda x: ')') unquoted_fragment << pp.Combine(pp.OneOrMore( - pp.Literal('\\"').set_debug(False) | - pp.Literal('\\').set_debug(False) | - pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"') + self_unescaping_escaped_rparen | self_unescaping_escaped_lparen | self_unescaping_escaped_quote | + pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') )) unquoted_fragment.set_parse_action(make_fragment).set_name('unquoted_fragment') - parenthesized_fragment << pp.Or([ - (lparen + quoted_fragment.set_parse_action(parse_fragment_str).set_debug(True) + rparen).set_name('-quoted_paren_internal').set_debug(True), - (lparen + rparen).set_parse_action(lambda x: make_fragment('')).set_name('-()').set_debug(True), + parenthesized_fragment << pp.MatchFirst([ + (lparen + quoted_fragment.copy().set_parse_action(parse_fragment_str).set_debug(False) + rparen).set_name('-quoted_paren_internal').set_debug(False), + (lparen + rparen).set_parse_action(lambda x: make_fragment('')).set_name('-()').set_debug(False), (lparen + pp.Combine(pp.OneOrMore( - pp.Literal('\\)').set_debug(False) | - pp.Literal('\\').set_debug(False) | - pp.Word(pp.printables, exclude_chars=string.whitespace + '\\)') | + pp.Literal('\\"').set_debug(False).set_parse_action(lambda x: '"') | + pp.Literal('\\(').set_debug(False).set_parse_action(lambda x: '(') | + pp.Literal('\\)').set_debug(False).set_parse_action(lambda x: ')') | + pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') | pp.Word(string.whitespace) - )).set_name('--combined').set_parse_action(parse_fragment_str).set_debug(True) + rparen)]).set_name('-unquoted_paren_internal').set_debug(True) - parenthesized_fragment.set_name('parenthesized_fragment').set_debug(True) + )).set_name('--combined').set_parse_action(parse_fragment_str).set_debug(False) + rparen)]).set_name('-unquoted_paren_internal').set_debug(False) + parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False) + debug_attention = False # attention control of the form +(phrase) / -(phrase) / (phrase) # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight attention_head = (number | pp.Word('+') | pp.Word('-'))\ .set_name("attention_head")\ .set_debug(False) - fragment_inside_attention = pp.CharsNotIn(SPACE_CHARS+'()')\ - .set_parse_action(make_fragment)\ - .set_name("fragment_inside_attention")\ - .set_debug(False) + word_inside_attention = pp.Combine(pp.OneOrMore( + pp.Literal('\\)') | pp.Literal('\\(') | pp.Literal('\\"') | + pp.Word(pp.printables, exclude_chars=string.whitespace + '\\()"') + )).set_name('word_inside_attention') attention_with_parens = pp.Forward() - attention_with_parens_body = pp.nested_expr(content=pp.delimited_list((attention_with_parens | fragment_inside_attention), delim=SPACE_CHARS)) + attention_with_parens_delimited_list = pp.delimited_list(pp.Or([ + quoted_fragment.copy().set_debug(debug_attention), + attention.copy().set_debug(debug_attention), + word_inside_attention.set_debug(debug_attention)]).set_name('delim_inner').set_debug(debug_attention), + delim=string.whitespace) + # have to disable ignore_expr here to prevent pyparsing from stripping off quote marks + attention_with_parens_body = pp.nested_expr(content=attention_with_parens_delimited_list, + ignore_expr=None#((pp.Literal("\\(") | pp.Literal('\\)'))) + ) + attention_with_parens_body.set_debug(debug_attention) attention_with_parens << (attention_head + attention_with_parens_body) + attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention) + + attention_without_parens = (pp.Word('+') | pp.Word('-')) + (quoted_fragment | word_inside_attention) + attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention) + + attention << (attention_with_parens | attention_without_parens) def make_attention(x): - # print("making Attention from parsing with args", x0, x1) + #print("making Attention from", x) weight = 1 # number(str) if type(x[0]) is float or type(x[0]) is int: @@ -318,26 +344,17 @@ class PromptParser(): elif type(x[0]) is str: base = self.attention_plus_base if x[0][0] == '+' else self.attention_minus_base weight = pow(base, len(x[0])) - # print("Making attention with children of type", [str(type(x)) for x in x1]) - return Attention(weight=weight, children=x[1]) + if type(x[1]) is list or type(x[1]) is pp.ParseResults: + return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in x[1]]) + elif type(x[1]) is str: + return Attention(weight=weight, children=[Fragment(x[1])]) + elif type(x[1]) is Fragment: + return Attention(weight=weight, children=[x[1]]) + raise PromptParser.ParsingException(f"Don't know how to make attention with children {x[1]}") - attention_with_parens.set_parse_action(make_attention)\ - .set_name("attention_with_parens")\ - .set_debug(False) - - # attention control of the form ++word --word (no parens) - attention_without_parens = ( - (pp.Word('+') | pp.Word('-')) + - pp.CharsNotIn(SPACE_CHARS+'()').set_parse_action(lambda x: [[make_fragment(x)]]) - )\ - .set_name("attention_without_parens")\ - .set_debug(False) + attention_with_parens.set_parse_action(make_attention) attention_without_parens.set_parse_action(make_attention) - attention << (attention_with_parens | attention_without_parens)\ - .set_name("attention")\ - .set_debug(False) - # cross-attention control empty_string = ((lparen + rparen) | pp.Literal('""').suppress() | @@ -345,26 +362,38 @@ class PromptParser(): ).set_parse_action(lambda x: Fragment("")) empty_string.set_name('empty_string') - original_fragment = empty_string | quoted_fragment | parenthesized_fragment | unquoted_fragment + + # cross attention control + debug_cross_attention_control = False + original_fragment = pp.Or([empty_string.set_debug(debug_cross_attention_control), + quoted_fragment.set_debug(debug_cross_attention_control), + parenthesized_fragment.set_debug(debug_cross_attention_control), + unquoted_fragment.set_debug(debug_cross_attention_control)]) edited_fragment = parenthesized_fragment cross_attention_substitute = original_fragment + pp.Literal(".swap").suppress() + edited_fragment - cross_attention_substitute.set_name('cross_attention_substitute').set_debug(True) + original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control) + edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control) + cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control) def make_cross_attention_substitute(x): - print("making cacs for", x) + #print("making cacs for", x) cacs = CrossAttentionControlSubstitute(x[0], x[1]) - print("made", cacs) + #print("made", cacs) return cacs cross_attention_substitute.set_parse_action(make_cross_attention_substitute) + + # simple fragments of text - prompt_part = ( - cross_attention_substitute - | attention - | quoted_fragment - | unquoted_fragment - ) + # use Or to match the longest + prompt_part = pp.Or([ + cross_attention_substitute, + attention, + quoted_fragment, + unquoted_fragment, + lparen + unquoted_fragment + rparen # matches case where user has +(term) and just deletes the + + ]) prompt_part.set_debug(False) prompt_part.set_name("prompt_part") @@ -373,8 +402,10 @@ class PromptParser(): (quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty') # root prompt definition - prompt = (pp.Group(pp.OneOrMore(prompt_part) | empty) + pp.StringEnd()) \ - .set_parse_action(lambda x: Prompt(x[0])) + prompt = ((pp.OneOrMore(prompt_part) | empty) + pp.StringEnd()) \ + .set_parse_action(lambda x: Prompt(x)) + + # weighted blend of prompts # ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or @@ -418,7 +449,7 @@ class PromptParser(): return Conjunction(parts, weights) conjunction_with_parens_and_quotes.set_parse_action(make_conjunction) - implicit_conjunction = pp.OneOrMore(blend | prompt) + implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction') implicit_conjunction.set_parse_action(lambda x: Conjunction(x)) conjunction = conjunction_with_parens_and_quotes | implicit_conjunction diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 38a24ca529..d053253eb6 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -1,5 +1,7 @@ import unittest +import pyparsing + from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute, \ Fragment @@ -11,39 +13,48 @@ def parse_prompt(prompt_string): #print(f"-> parsed '{prompt_string}' to {parse_result}") return parse_result +def make_basic_conjunction(strings: list[str]): + fragments = [Fragment(x) for x in strings] + return Conjunction([FlattenedPrompt(fragments)]) + +def make_weighted_conjunction(weighted_strings: list[tuple[str,float]]): + fragments = [Fragment(x, w) for x,w in weighted_strings] + return Conjunction([FlattenedPrompt(fragments)]) + + class PromptParserTestCase(unittest.TestCase): def test_empty(self): - self.assertEqual(Conjunction([FlattenedPrompt([('', 1)])]), parse_prompt('')) + self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt('')) def test_basic(self): - self.assertEqual(Conjunction([FlattenedPrompt([('fire (flames)', 1)])]), parse_prompt("fire (flames)")) - self.assertEqual(Conjunction([FlattenedPrompt([("fire flames", 1)])]), parse_prompt("fire flames")) - self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames", 1)])]), parse_prompt("fire, flames")) - self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames , fire", 1)])]), parse_prompt("fire, flames , fire")) + self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)")) + self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames")) + self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames")) + self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire")) def test_attention(self): - self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.5)])]), parse_prompt("0.5(flames)")) - self.assertEqual(Conjunction([FlattenedPrompt([('fire flames', 0.5)])]), parse_prompt("0.5(fire flames)")) - self.assertEqual(Conjunction([FlattenedPrompt([('flames', 1.1)])]), parse_prompt("+(flames)")) - self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.9)])]), parse_prompt("-(flames)")) - self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1), ('flames', 0.5)])]), parse_prompt("fire 0.5(flames)")) - self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(1.1, 2))])]), parse_prompt("++(flames)")) - self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(0.9, 2))])]), parse_prompt("--(flames)")) - self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))])]), parse_prompt("---(flowers) +++flames")) - self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))])]), parse_prompt("---(flowers) +++flames")) - self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames+', pow(1.1, 3))])]), + self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("0.5(flames)")) + self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("0.5(fire flames)")) + self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("+(flames)")) + self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("-(flames)")) + self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire 0.5(flames)")) + self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("++(flames)")) + self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("--(flames)")) + self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("---(flowers) +++flames")) + self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("---(flowers) +++flames")) + self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames+', pow(1.1, 3))]), parse_prompt("---(flowers) +++flames+")) - self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1)])]), + self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1)]), parse_prompt("+(pretty flowers)")) - self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1), (', the flames are too hot', 1)])]), + self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1), (', the flames are too hot', 1)]), parse_prompt("+(pretty flowers), the flames are too hot")) def test_no_parens_attention_runon(self): - self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("++fire flames")) - self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("--fire flames")) - self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("flowers ++fire flames")) - self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("flowers --fire flames")) + self.assertEqual(make_weighted_conjunction([('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("++fire flames")) + self.assertEqual(make_weighted_conjunction([('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("--fire flames")) + self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("flowers ++fire flames")) + self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("flowers --fire flames")) def test_explicit_conjunction(self): @@ -75,17 +86,27 @@ class PromptParserTestCase(unittest.TestCase): self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt)) assert_if_prompt_string_not_untouched('a test prompt') - assert_if_prompt_string_not_untouched('a badly (formed test prompt') assert_if_prompt_string_not_untouched('a badly formed test+ prompt') - assert_if_prompt_string_not_untouched('a badly (formed test+ prompt') - assert_if_prompt_string_not_untouched('a badly (formed test+ )prompt') - assert_if_prompt_string_not_untouched('a badly (formed test+ )prompt') - assert_if_prompt_string_not_untouched('(((a badly (formed test+ )prompt') - assert_if_prompt_string_not_untouched('(a (ba)dly (f)ormed test+ prompt') - self.assertEqual(Conjunction([FlattenedPrompt([('(a (ba)dly (f)ormed test+', 1.0), ('prompt', 1.1)])]), - parse_prompt('(a (ba)dly (f)ormed test+ +prompt')) - self.assertEqual(Conjunction([Blend([FlattenedPrompt([('((a badly (formed test+', 1.0)])], weights=[1.0])]), - parse_prompt('("((a badly (formed test+ ").blend(1.0)')) + with self.assertRaises(pyparsing.ParseException): + parse_prompt('a badly (formed test prompt') + #with self.assertRaises(pyparsing.ParseException): + with self.assertRaises(pyparsing.ParseException): + parse_prompt('a badly (formed test+ prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('a badly (formed test+ )prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('a badly (formed test+ )prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('(((a badly (formed test+ )prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('(a (ba)dly (f)ormed test+ prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('(a (ba)dly (f)ormed test+ +prompt') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('("((a badly (formed test+ ").blend(1.0)') + with self.assertRaises(pyparsing.ParseException): + parse_prompt('mountain (\\"man").swap("monkey")') + def test_blend(self): self.assertEqual(Conjunction( @@ -127,7 +148,7 @@ class PromptParserTestCase(unittest.TestCase): def test_nested(self): - self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)])]), + self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)]), parse_prompt('fire 2.0(flames 1.5(trees))')) self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]), FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])], @@ -202,20 +223,84 @@ class PromptParserTestCase(unittest.TestCase): self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)')) - def make_basic_conjunction(self, strings: list[str]): - fragments = [Fragment(x) for x in strings] - return Conjunction([FlattenedPrompt(fragments)]) - - def make_weighted_conjunction(self, weighted_strings: list[tuple[str,float]]): - fragments = [Fragment(x, w) for x,w in weighted_strings] - return Conjunction([FlattenedPrompt(fragments)]) - def test_escaping(self): - self.assertEqual(self.make_basic_conjunction(['mountain \(man\)']),parse_prompt('mountain \(man\)')) - self.assertEqual(self.make_basic_conjunction(['mountain (\(man)\)']),parse_prompt('mountain (\(man)\)')) - self.assertEqual(self.make_basic_conjunction(['mountain (\(man\))']),parse_prompt('mountain (\(man\))')) - #self.assertEqual(self.make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain +(\(man\))')) + + # make sure ", ( and ) can be escaped + + self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain \(man\)')) + self.assertEqual(make_basic_conjunction(['mountain (man )']),parse_prompt('mountain (\(man)\)')) + self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain (\(man\))')) + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('mountain +(\(man\))')) + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('"mountain" +(\(man\))')) + self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('(man)', 1.1)]), parse_prompt('\\"mountain\\" +(\(man\))')) + # same weights for each are combined into one + self.assertEqual(make_weighted_conjunction([('"mountain" (man)', 1.1)]), parse_prompt('+(\\"mountain\\") +(\(man\))')) + self.assertEqual(make_weighted_conjunction([('"mountain"', 1.1), ('(man)', 0.9)]), parse_prompt('+(\\"mountain\\") -(\(man\))')) + + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain 1.1(\(man\))')) + self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('"mountain" 1.1(\(man\))')) + self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('\(man\)', 1.1)]),parse_prompt('\\"mountain\\" 1.1(\(man\))')) + # same weights for each are combined into one + self.assertEqual(make_weighted_conjunction([('\\"mountain\\" \(man\)', 1.1)]),parse_prompt('+(\\"mountain\\") 1.1(\(man\))')) + self.assertEqual(make_weighted_conjunction([('\\"mountain\\"', 1.1), ('\(man\)', 0.9)]),parse_prompt('1.1(\\"mountain\\") 0.9(\(man\))')) + + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy +(mountain +(\(man\)))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy +(1.1(\(man\)) "mountain")')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy +("mountain" 1.1(\(man\)) )')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man', 1.1)]),parse_prompt('hairy +("mountain, man")')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy +("mountain, man" with a +beard)')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, man" with a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\"man\\"" with a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, m\\"an\\"" with a 2.0(beard))')) + + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" \(with a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" w\(ith a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" with\( a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" \)with a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" w\)ith a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" with\) a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mou\)ntain, \\\"man\" \(wit\(h a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry +("mountain, \\\"man\" w\)ith a 2.0(beard))')) + self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain, \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( +("mountain, \\\"man\" with a 2.0(beard))')) + + self.assertEqual(make_weighted_conjunction([('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" \(with a 2.0(beard)) hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" w\(ith a 2.0(beard))hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" with\( a 2.0(beard)) hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" \)with a 2.0(beard)) hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" w\)ith a 2.0(beard)) hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' +("mountain, \\\"man\" with\) a 2.0(beard)) hairy')) + self.assertEqual(make_weighted_conjunction([('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mou\)ntain, \\\"man\" \(wit\(h a 2.0(beard)) hairy')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('+("mountain, \\\"man\" w\)ith a 2.0(beard)) hai\(ry ')) + self.assertEqual(make_weighted_conjunction([('mountain, \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('+("mountain, \\\"man\" with a 2.0(beard)) hairy\(\( ')) + + def test_cross_attention_escaping(self): + + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]), + parse_prompt('mountain (man).swap(monkey)')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]), + parse_prompt('mountain (man).swap(m\(onkey)')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]), + parse_prompt('mountain (m\(an).swap(m\(onkey)')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]), + parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)')) + + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]), + parse_prompt('mountain ("man").swap(monkey)')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]), + parse_prompt('mountain ("man").swap("monkey")')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('"man', 1)], [Fragment('monkey', 1)])])]), + parse_prompt('mountain (\\"man).swap("monkey")')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]), + parse_prompt('mountain (man).swap(m\(onkey)')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]), + parse_prompt('mountain (m\(an).swap(m\(onkey)')) + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]), + parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)')) + + def test_single(self): + self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('"man', 1)], [Fragment('monkey', 1)])])]), + parse_prompt('mountain (\\"man).swap("monkey")')) if __name__ == '__main__':