diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 4a6d470140..7e9428d820 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -417,7 +417,7 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) #print("parsed to", result) return result except pp.ParseException as e: - print("parse_fragment_str couldn't parse prompt string:", e) + #print("parse_fragment_str couldn't parse prompt string:", e) raise quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"') @@ -445,14 +445,17 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) unquoted_word.set_parse_action(make_text_fragment).set_name('unquoted_word').set_debug(False) #print(unquoted_fragment.parse_string("cat.swap(dog)")) - parenthesized_fragment << pp.Or([ - (lparen + quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False) + rparen).set_name('-quoted_paren_internal').set_debug(False), - (lparen + rparen).set_parse_action(lambda x: make_text_fragment('')).set_name('-()').set_debug(False), - (lparen + pp.Combine(pp.OneOrMore( + parenthesized_fragment << (lparen + + pp.Or([ + (parenthesized_fragment), + (quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False)).set_name('-quoted_paren_internal').set_debug(False), + (pp.Combine(pp.OneOrMore( escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash | pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') | pp.Word(string.whitespace) - )).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False) + rparen)]).set_name('-unquoted_paren_internal').set_debug(False) + )).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False)), + pp.Empty() + ]) + rparen) parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False) debug_attention = False @@ -472,11 +475,12 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) + rparen + attention_with_parens_foot) attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention) - attention_without_parens_foot = pp.Or(pp.Word('+') | pp.Word('-')).set_name('attention_without_parens_foots') - attention_without_parens <<= pp.Group( - (quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot) | + attention_without_parens_foot = pp.NotAny(pp.White()) + pp.Or(pp.Word('+') | pp.Word('-')).set_name('attention_without_parens_foots') + attention_without_parens <<= pp.Group(pp.MatchFirst([ + quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot, pp.Combine(build_escaped_word_parser_charbychar('()+-')).set_name('attention_word_without_parens').set_debug(debug_attention)#.set_parse_action(lambda x: print('escapéd', x)) - + attention_without_parens_foot)#.leave_whitespace() + + attention_without_parens_foot#.leave_whitespace() + ])) attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention) @@ -553,7 +557,7 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) prompt = (pp.OneOrMore(pp.Or([cross_attention_substitute.set_debug(debug_root_prompt), attention.set_debug(debug_root_prompt), quoted_fragment.set_debug(debug_root_prompt), - (lparen + (pp.ZeroOrMore(unquoted_word | pp.White().suppress()).leave_whitespace()) + rparen).set_name('parenthesized-uqw').set_debug(debug_root_prompt), + parenthesized_fragment.set_debug(debug_root_prompt), unquoted_word.set_debug(debug_root_prompt), empty.set_parse_action(make_text_fragment).set_debug(debug_root_prompt)]) ) + pp.StringEnd()) \ diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 839093289a..7a82dca2b1 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -117,6 +117,15 @@ class PromptParserTestCase(unittest.TestCase): with self.assertRaises(pyparsing.ParseException): parse_prompt('("((a badly (formed +test ").blend(1.0)') + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]), + parse_prompt("hamburger ((bun))")) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]), + parse_prompt("hamburger (bun)")) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger kaiser roll', 1)])]), + parse_prompt("hamburger (kaiser roll)")) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger kaiser roll', 1)])]), + parse_prompt("hamburger ((kaiser roll))")) + def test_blend(self): self.assertEqual(Conjunction( @@ -284,6 +293,7 @@ class PromptParserTestCase(unittest.TestCase): ])]), parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog-, shape_freedom=0.5)")) + def test_cross_attention_control_options(self): self.assertEqual(Conjunction([ FlattenedPrompt([Fragment('a', 1), @@ -426,9 +436,9 @@ class PromptParserTestCase(unittest.TestCase): # todo handle this #self.assertEqual(make_basic_conjunction(['a badly formed +test prompt']), # parse_prompt('a badly formed +test prompt')) - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), - CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]), - parse_prompt('a forest landscape "in winter".swap()')) + trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \ + CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])]) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")')) pass