From 79b4afeae7f0e2e4f922de26e6d9a458fea5c46b Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Thu, 20 Oct 2022 16:56:34 +0200 Subject: [PATCH] parser working with basic escapes --- ldm/invoke/prompt_parser.py | 48 +++++++++++++++++++++++++------------ tests/test_prompt_parser.py | 21 ++++++++++++---- 2 files changed, 50 insertions(+), 19 deletions(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index abd9ce726c..398a596c7e 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -169,12 +169,16 @@ class PromptParser(): def flatten(self, root: Conjunction): + print("flattening", root) + def fuse_fragments(items): # print("fusing fragments in ", items) result = [] for x in items: - if issubclass(type(x), CrossAttentionControlledFragment): - result.append(x) + if type(x) is CrossAttentionControlSubstitute: + original_fused = fuse_fragments(x.original) + edited_fused = fuse_fragments(x.edited) + result.append(CrossAttentionControlSubstitute(original_fused, edited_fused)) else: last_weight = result[-1].weight \ if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \ @@ -221,10 +225,9 @@ class PromptParser(): #print(prefix + "after flattening Prompt, results is", results) else: raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}") - print(prefix + "-> after flattening", type(node), "results is", results) + print(prefix + "-> after flattening", type(node).__name__, "results is", results) return results - #print("flattening", root) flattened_parts = [] for part in root.prompts: @@ -244,6 +247,8 @@ class PromptParser(): number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float)) SPACE_CHARS = string.whitespace + attention = pp.Forward() + def make_fragment(x): #print("### making fragment for", x) if type(x) is str: @@ -254,33 +259,44 @@ class PromptParser(): else: raise PromptParser.ParsingException("Cannot make fragment from " + str(x)) + unquoted_fragment = pp.Forward() + quoted_fragment = pp.Forward() + parenthesized_fragment = pp.Forward() + def parse_fragment_str(x): - return make_fragment(x) + print("parsing", x) + if len(x[0].strip()) == 0: + return Fragment('') + fragment_parser = pp.Group(pp.OneOrMore(attention | pp.Word(pp.printables, exclude_chars=string.whitespace).set_parse_action(make_fragment))) + fragment_parser.set_name('word_or_attention') + result = fragment_parser.parse_string(x[0]) + #result = (pp.OneOrMore(attention | unquoted_fragment) + pp.StringEnd()).parse_string(x[0]) + print("parsed to", result) + return result - quoted_fragment = pp.QuotedString(quote_char='"', esc_char='\\') - quoted_fragment.set_parse_action(parse_fragment_str).set_name('quoted_fragment') + quoted_fragment << pp.QuotedString(quote_char='"', esc_char='\\') + quoted_fragment.set_parse_action(make_fragment).set_name('quoted_fragment') - unquoted_fragment = pp.Combine(pp.OneOrMore( + unquoted_fragment << pp.Combine(pp.OneOrMore( pp.Literal('\\"').set_debug(False) | pp.Literal('\\').set_debug(False) | pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"') )) - unquoted_fragment.set_parse_action(parse_fragment_str).set_name('unquoted_fragment') + unquoted_fragment.set_parse_action(make_fragment).set_name('unquoted_fragment') - parenthesized_fragment = \ - (lparen + quoted_fragment.set_debug(True) + rparen).set_name('quoted_paren_internal') | \ - (lparen + rparen).set_parse_action(lambda x: make_fragment('')) | \ + parenthesized_fragment << pp.Or([ + (lparen + quoted_fragment.set_parse_action(parse_fragment_str).set_debug(True) + rparen).set_name('-quoted_paren_internal').set_debug(True), + (lparen + rparen).set_parse_action(lambda x: make_fragment('')).set_name('-()').set_debug(True), (lparen + pp.Combine(pp.OneOrMore( pp.Literal('\\)').set_debug(False) | pp.Literal('\\').set_debug(False) | pp.Word(pp.printables, exclude_chars=string.whitespace + '\\)') | pp.Word(string.whitespace) - )) + rparen).set_parse_action(parse_fragment_str).set_name('unquoted_paren_internal').set_debug(True) + )).set_name('--combined').set_parse_action(parse_fragment_str).set_debug(True) + rparen)]).set_name('-unquoted_paren_internal').set_debug(True) parenthesized_fragment.set_name('parenthesized_fragment').set_debug(True) # attention control of the form +(phrase) / -(phrase) / (phrase) # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight - attention = pp.Forward() attention_head = (number | pp.Word('+') | pp.Word('-'))\ .set_name("attention_head")\ .set_debug(False) @@ -352,7 +368,9 @@ class PromptParser(): prompt_part.set_debug(False) prompt_part.set_name("prompt_part") - empty = ((lparen + rparen) | (quotes + quotes)).suppress() + empty = ( + (lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) | + (quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty') # root prompt definition prompt = (pp.Group(pp.OneOrMore(prompt_part) | empty) + pp.StringEnd()) \ diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 0aa0cfd6ae..38a24ca529 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -174,7 +174,7 @@ class PromptParserTestCase(unittest.TestCase): CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]), parse_prompt('a forest landscape "".swap("in winter")')) self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), - CrossAttentionControlSubstitute([Fragment(' ',1)], [Fragment('in winter',1)])])]), + CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]), parse_prompt('a forest landscape " ".swap("in winter")')) self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), @@ -184,7 +184,7 @@ class PromptParserTestCase(unittest.TestCase): CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]), parse_prompt('a forest landscape "in winter".swap()')) self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), - CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment(' ',1)])])]), + CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]), parse_prompt('a forest landscape "in winter".swap(" ")')) def test_cross_attention_control_with_attention(self): @@ -201,8 +201,21 @@ class PromptParserTestCase(unittest.TestCase): Fragment(',', 1), Fragment('fire', 2.0)])]) self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)')) - def test_single(self): - print(parse_prompt('fire (trees and houses).swap("flames")')) + + def make_basic_conjunction(self, strings: list[str]): + fragments = [Fragment(x) for x in strings] + return Conjunction([FlattenedPrompt(fragments)]) + + def make_weighted_conjunction(self, weighted_strings: list[tuple[str,float]]): + fragments = [Fragment(x, w) for x,w in weighted_strings] + return Conjunction([FlattenedPrompt(fragments)]) + + + def test_escaping(self): + self.assertEqual(self.make_basic_conjunction(['mountain \(man\)']),parse_prompt('mountain \(man\)')) + self.assertEqual(self.make_basic_conjunction(['mountain (\(man)\)']),parse_prompt('mountain (\(man)\)')) + self.assertEqual(self.make_basic_conjunction(['mountain (\(man\))']),parse_prompt('mountain (\(man\))')) + #self.assertEqual(self.make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain +(\(man\))')) if __name__ == '__main__':