ok now we're cooking

This commit is contained in:
Damian at mba
2022-10-21 03:29:50 +02:00
parent da75876639
commit 2e0b1c4c8b
2 changed files with 273 additions and 202 deletions

View File

@ -77,6 +77,15 @@ class PromptParserTestCase(unittest.TestCase):
def test_complex_conjunction(self):
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]),
parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle)\").and(0.5, 0.5)"))
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]),
FlattenedPrompt([("a person with a hat", 1.0),
("riding a", 1.1*1.1),
CrossAttentionControlSubstitute(
[Fragment("bicycle", pow(1.1,2))],
[Fragment("skateboard", pow(1.1,2))])
])
], weights=[0.5, 0.5]),
parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle.swap(skateboard))\").and(0.5, 0.5)"))
def test_badly_formed(self):
def make_untouched_prompt(prompt):
@ -309,8 +318,7 @@ class PromptParserTestCase(unittest.TestCase):
parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))
def test_single(self):
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('"man', 1)], [Fragment('monkey', 1)])])]),
parse_prompt('mountain (\\"man).swap("monkey")'))
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mou\)ntain, \\\"man\" \(wit\(h a 2.0(beard))'))
if __name__ == '__main__':