parser working with basic escapes

This commit is contained in:
Damian at mba
2022-10-20 16:56:34 +02:00
parent da223dfe81
commit 79b4afeae7
2 changed files with 50 additions and 19 deletions

View File

@ -174,7 +174,7 @@ class PromptParserTestCase(unittest.TestCase):
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
parse_prompt('a forest landscape "".swap("in winter")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment(' ',1)], [Fragment('in winter',1)])])]),
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
parse_prompt('a forest landscape " ".swap("in winter")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
@ -184,7 +184,7 @@ class PromptParserTestCase(unittest.TestCase):
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
parse_prompt('a forest landscape "in winter".swap()'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment(' ',1)])])]),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
parse_prompt('a forest landscape "in winter".swap(" ")'))
def test_cross_attention_control_with_attention(self):
@ -201,8 +201,21 @@ class PromptParserTestCase(unittest.TestCase):
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)'))
def test_single(self):
print(parse_prompt('fire (trees and houses).swap("flames")'))
def make_basic_conjunction(self, strings: list[str]):
fragments = [Fragment(x) for x in strings]
return Conjunction([FlattenedPrompt(fragments)])
def make_weighted_conjunction(self, weighted_strings: list[tuple[str,float]]):
fragments = [Fragment(x, w) for x,w in weighted_strings]
return Conjunction([FlattenedPrompt(fragments)])
def test_escaping(self):
self.assertEqual(self.make_basic_conjunction(['mountain \(man\)']),parse_prompt('mountain \(man\)'))
self.assertEqual(self.make_basic_conjunction(['mountain (\(man)\)']),parse_prompt('mountain (\(man)\)'))
self.assertEqual(self.make_basic_conjunction(['mountain (\(man\))']),parse_prompt('mountain (\(man\))'))
#self.assertEqual(self.make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain +(\(man\))'))
if __name__ == '__main__':