move attention weighting operations to postfix

This commit is contained in:
Damian at mba
2022-10-23 23:01:53 +02:00
parent f7cd98c238
commit b0eb864a25
2 changed files with 147 additions and 117 deletions

View File

@ -34,27 +34,28 @@ class PromptParserTestCase(unittest.TestCase):
self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire"))
def test_attention(self):
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("0.5(flames)"))
self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("0.5(fire flames)"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("+(flames)"))
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("-(flames)"))
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire 0.5(flames)"))
self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("++(flames)"))
self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("--(flames)"))
self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("---(flowers) +++flames"))
self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("---(flowers) +++flames"))
self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames+', pow(1.1, 3))]),
parse_prompt("---(flowers) +++flames+"))
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5"))
self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+"))
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+"))
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-"))
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-"))
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("\"flames\"-"))
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5"))
self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++"))
self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--"))
self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++"))
self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1)]),
parse_prompt("+(pretty flowers)"))
parse_prompt("(pretty flowers)+"))
self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1), (', the flames are too hot', 1)]),
parse_prompt("+(pretty flowers), the flames are too hot"))
parse_prompt("(pretty flowers)+, the flames are too hot"))
def test_no_parens_attention_runon(self):
self.assertEqual(make_weighted_conjunction([('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("++fire flames"))
self.assertEqual(make_weighted_conjunction([('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("--fire flames"))
self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("flowers ++fire flames"))
self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("flowers --fire flames"))
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(1.1, 2))]), parse_prompt("fire flames++"))
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(0.9, 2))]), parse_prompt("fire flames--"))
self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("flowers fire++ flames"))
self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("flowers fire-- flames"))
def test_explicit_conjunction(self):
@ -62,7 +63,7 @@ class PromptParserTestCase(unittest.TestCase):
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()'))
self.assertEqual(
Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("2.0(fire)", "-flames").and()'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("(fire)2.0", "flames-").and()'))
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]),
FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()'))
@ -75,8 +76,11 @@ class PromptParserTestCase(unittest.TestCase):
parse_prompt('("fire", "flames").and(2,1,2)')
def test_complex_conjunction(self):
#print(parse_prompt("a person with a hat (riding a bicycle.swap(skateboard))++"))
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]),
parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle)\").and(0.5, 0.5)"))
parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle)++\").and(0.5, 0.5)"))
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]),
FlattenedPrompt([("a person with a hat", 1.0),
("riding a", 1.1*1.1),
@ -85,7 +89,7 @@ class PromptParserTestCase(unittest.TestCase):
[Fragment("skateboard", pow(1.1,2))])
])
], weights=[0.5, 0.5]),
parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle.swap(skateboard))\").and(0.5, 0.5)"))
parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)"))
def test_badly_formed(self):
def make_untouched_prompt(prompt):
@ -95,24 +99,25 @@ class PromptParserTestCase(unittest.TestCase):
self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt))
assert_if_prompt_string_not_untouched('a test prompt')
assert_if_prompt_string_not_untouched('a badly formed test+ prompt')
# todo handle this
#assert_if_prompt_string_not_untouched('a badly formed +test prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('a badly (formed test prompt')
#with self.assertRaises(pyparsing.ParseException):
with self.assertRaises(pyparsing.ParseException):
parse_prompt('a badly (formed test+ prompt')
parse_prompt('a badly (formed +test prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('a badly (formed test+ )prompt')
parse_prompt('a badly (formed +test )prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('a badly (formed test+ )prompt')
parse_prompt('a badly (formed +test )prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('(((a badly (formed test+ )prompt')
parse_prompt('(((a badly (formed +test )prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('(a (ba)dly (f)ormed test+ prompt')
parse_prompt('(a (ba)dly (f)ormed +test prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('(a (ba)dly (f)ormed test+ +prompt')
parse_prompt('(a (ba)dly (f)ormed +test +prompt')
with self.assertRaises(pyparsing.ParseException):
parse_prompt('("((a badly (formed test+ ").blend(1.0)')
parse_prompt('("((a badly (formed +test ").blend(1.0)')
def test_blend(self):
@ -129,7 +134,7 @@ class PromptParserTestCase(unittest.TestCase):
FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]),
FlattenedPrompt([('hi', 1.0)])],
weights=[0.7, 0.3, 1.0])]),
parse_prompt("(\"fire\", \"fire flames ++(hot)\", \"hi\").blend(0.7, 0.3, 1.0)")
parse_prompt("(\"fire\", \"fire flames (hot)++\", \"hi\").blend(0.7, 0.3, 1.0)")
)
# blend a single entry is not a failure
self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]),
@ -156,17 +161,17 @@ class PromptParserTestCase(unittest.TestCase):
self.assertEqual(
Conjunction([Blend([FlattenedPrompt([('mountain, man, hairy', 1)]),
FlattenedPrompt([('face, teeth,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0])]),
parse_prompt('("mountain, man, hairy", "face, teeth, --eyes").blend(1,-1)')
parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)')
)
def test_nested(self):
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)]),
parse_prompt('fire 2.0(flames 1.5(trees))'))
parse_prompt('fire (flames (trees)1.5)2.0'))
self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]),
FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])],
weights=[1.0, 1.0])]),
parse_prompt('("fire ++(flames)", "mountain 2(man)").blend(1,1)'))
parse_prompt('("fire (flames)++", "mountain (man)2").blend(1,1)'))
def test_cross_attention_control(self):
@ -237,15 +242,15 @@ class PromptParserTestCase(unittest.TestCase):
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('flames',0.5)], [Fragment('trees',0.7)]),
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(flames)".swap("0.7(trees)"), 2.0(fire)'))
self.assertEqual(flames_to_trees_fire, parse_prompt('"(flames)0.5".swap("(trees)0.7"), (fire)2.0'))
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7)]),
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees)"), 2.0(fire)'))
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7"), (fire)2.0'))
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7), Fragment('houses', 1)]),
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)'))
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0'))
def test_cross_attention_control_options(self):
self.assertEqual(Conjunction([
@ -271,48 +276,48 @@ class PromptParserTestCase(unittest.TestCase):
self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain \(man\)'))
self.assertEqual(make_basic_conjunction(['mountain (man )']),parse_prompt('mountain (\(man)\)'))
self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain (\(man\))'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('mountain +(\(man\))'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('"mountain" +(\(man\))'))
self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('(man)', 1.1)]), parse_prompt('\\"mountain\\" +(\(man\))'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('mountain (\(man\))+'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('"mountain" (\(man\))+'))
self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('(man)', 1.1)]), parse_prompt('\\"mountain\\" (\(man\))+'))
# same weights for each are combined into one
self.assertEqual(make_weighted_conjunction([('"mountain" (man)', 1.1)]), parse_prompt('+(\\"mountain\\") +(\(man\))'))
self.assertEqual(make_weighted_conjunction([('"mountain"', 1.1), ('(man)', 0.9)]), parse_prompt('+(\\"mountain\\") -(\(man\))'))
self.assertEqual(make_weighted_conjunction([('"mountain" (man)', 1.1)]), parse_prompt('(\\"mountain\\")+ (\(man\))+'))
self.assertEqual(make_weighted_conjunction([('"mountain"', 1.1), ('(man)', 0.9)]), parse_prompt('(\\"mountain\\")+ (\(man\))-'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain 1.1(\(man\))'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('"mountain" 1.1(\(man\))'))
self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('\(man\)', 1.1)]),parse_prompt('\\"mountain\\" 1.1(\(man\))'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain (\(man\))1.1'))
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('"mountain" (\(man\))1.1'))
self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('\(man\)', 1.1)]),parse_prompt('\\"mountain\\" (\(man\))1.1'))
# same weights for each are combined into one
self.assertEqual(make_weighted_conjunction([('\\"mountain\\" \(man\)', 1.1)]),parse_prompt('+(\\"mountain\\") 1.1(\(man\))'))
self.assertEqual(make_weighted_conjunction([('\\"mountain\\"', 1.1), ('\(man\)', 0.9)]),parse_prompt('1.1(\\"mountain\\") 0.9(\(man\))'))
self.assertEqual(make_weighted_conjunction([('\\"mountain\\" \(man\)', 1.1)]),parse_prompt('(\\"mountain\\")+ (\(man\))1.1'))
self.assertEqual(make_weighted_conjunction([('\\"mountain\\"', 1.1), ('\(man\)', 0.9)]),parse_prompt('(\\"mountain\\")1.1 (\(man\))0.9'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy +(mountain +(\(man\)))'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy +(1.1(\(man\)) "mountain")'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy +("mountain" 1.1(\(man\)) )'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man', 1.1)]),parse_prompt('hairy +("mountain, man")'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy +("mountain, man" with a +beard)'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, man" with a 2.0(beard))'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\"man\\"" with a 2.0(beard))'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, m\\"an\\"" with a 2.0(beard))'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy (mountain (\(man\))+)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy ((\(man\))1.1 "mountain")+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy ("mountain" (\(man\))1.1 )+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man', 1.1)]),parse_prompt('hairy ("mountain, man")+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy ("mountain, man" with a beard+)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, man" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\"man\\"" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, m\\"an\\"" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" \(with a 2.0(beard))'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" w\(ith a 2.0(beard))'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" with\( a 2.0(beard))'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" \)with a 2.0(beard))'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" w\)ith a 2.0(beard))'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mountain, \\\"man\" with\) a 2.0(beard))'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mou\)ntain, \\\"man\" \(wit\(h a 2.0(beard))'))
self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry +("mountain, \\\"man\" w\)ith a 2.0(beard))'))
self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain, \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( +("mountain, \\\"man\" with a 2.0(beard))'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \(with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\(ith a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\( a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \)with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\) a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain, \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( ("mountain, \\\"man\" with a (beard)2.0)+'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" \(with a 2.0(beard)) hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" w\(ith a 2.0(beard))hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" with\( a 2.0(beard)) hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" \)with a 2.0(beard)) hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mountain, \\\"man\" w\)ith a 2.0(beard)) hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' +("mountain, \\\"man\" with\) a 2.0(beard)) hairy'))
self.assertEqual(make_weighted_conjunction([('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('+("mou\)ntain, \\\"man\" \(wit\(h a 2.0(beard)) hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('+("mountain, \\\"man\" w\)ith a 2.0(beard)) hai\(ry '))
self.assertEqual(make_weighted_conjunction([('mountain, \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('+("mountain, \\\"man\" with a 2.0(beard)) hairy\(\( '))
self.assertEqual(make_weighted_conjunction([('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \(with a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\(ith a (beard)2.0)+hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" with\( a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \)with a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' ("mountain, \\\"man\" with\) a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+ hairy'))
self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hai\(ry '))
self.assertEqual(make_weighted_conjunction([('mountain, \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('("mountain, \\\"man\" with a (beard)2.0)+ hairy\(\( '))
def test_cross_attention_escaping(self):
@ -339,7 +344,10 @@ class PromptParserTestCase(unittest.TestCase):
parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))
def test_single(self):
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mou\)ntain, \\\"man\" \(wit\(h a 2.0(beard))'))
# todo handle this
#self.assertEqual(make_basic_conjunction(['a badly formed +test prompt']),
# parse_prompt('a badly formed +test prompt'))
pass
if __name__ == '__main__':