be more forgiving about prompts with ((words))

This commit is contained in:
Damian at mba 2022-10-27 22:36:33 +02:00
parent 30745f163d
commit 245cf606a3
2 changed files with 28 additions and 14 deletions

View File

@ -417,7 +417,7 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
#print("parsed to", result) #print("parsed to", result)
return result return result
except pp.ParseException as e: except pp.ParseException as e:
print("parse_fragment_str couldn't parse prompt string:", e) #print("parse_fragment_str couldn't parse prompt string:", e)
raise raise
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"') quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
@ -445,14 +445,17 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
unquoted_word.set_parse_action(make_text_fragment).set_name('unquoted_word').set_debug(False) unquoted_word.set_parse_action(make_text_fragment).set_name('unquoted_word').set_debug(False)
#print(unquoted_fragment.parse_string("cat.swap(dog)")) #print(unquoted_fragment.parse_string("cat.swap(dog)"))
parenthesized_fragment << pp.Or([ parenthesized_fragment << (lparen +
(lparen + quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False) + rparen).set_name('-quoted_paren_internal').set_debug(False), pp.Or([
(lparen + rparen).set_parse_action(lambda x: make_text_fragment('')).set_name('-()').set_debug(False), (parenthesized_fragment),
(lparen + pp.Combine(pp.OneOrMore( (quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False)).set_name('-quoted_paren_internal').set_debug(False),
(pp.Combine(pp.OneOrMore(
escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash | escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash |
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') | pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') |
pp.Word(string.whitespace) pp.Word(string.whitespace)
)).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False) + rparen)]).set_name('-unquoted_paren_internal').set_debug(False) )).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False)),
pp.Empty()
]) + rparen)
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False) parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False)
debug_attention = False debug_attention = False
@ -472,11 +475,12 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
+ rparen + attention_with_parens_foot) + rparen + attention_with_parens_foot)
attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention) attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention)
attention_without_parens_foot = pp.Or(pp.Word('+') | pp.Word('-')).set_name('attention_without_parens_foots') attention_without_parens_foot = pp.NotAny(pp.White()) + pp.Or(pp.Word('+') | pp.Word('-')).set_name('attention_without_parens_foots')
attention_without_parens <<= pp.Group( attention_without_parens <<= pp.Group(pp.MatchFirst([
(quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot) | quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot,
pp.Combine(build_escaped_word_parser_charbychar('()+-')).set_name('attention_word_without_parens').set_debug(debug_attention)#.set_parse_action(lambda x: print('escapéd', x)) pp.Combine(build_escaped_word_parser_charbychar('()+-')).set_name('attention_word_without_parens').set_debug(debug_attention)#.set_parse_action(lambda x: print('escapéd', x))
+ attention_without_parens_foot)#.leave_whitespace() + attention_without_parens_foot#.leave_whitespace()
]))
attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention) attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention)
@ -553,7 +557,7 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
prompt = (pp.OneOrMore(pp.Or([cross_attention_substitute.set_debug(debug_root_prompt), prompt = (pp.OneOrMore(pp.Or([cross_attention_substitute.set_debug(debug_root_prompt),
attention.set_debug(debug_root_prompt), attention.set_debug(debug_root_prompt),
quoted_fragment.set_debug(debug_root_prompt), quoted_fragment.set_debug(debug_root_prompt),
(lparen + (pp.ZeroOrMore(unquoted_word | pp.White().suppress()).leave_whitespace()) + rparen).set_name('parenthesized-uqw').set_debug(debug_root_prompt), parenthesized_fragment.set_debug(debug_root_prompt),
unquoted_word.set_debug(debug_root_prompt), unquoted_word.set_debug(debug_root_prompt),
empty.set_parse_action(make_text_fragment).set_debug(debug_root_prompt)]) empty.set_parse_action(make_text_fragment).set_debug(debug_root_prompt)])
) + pp.StringEnd()) \ ) + pp.StringEnd()) \

View File

@ -117,6 +117,15 @@ class PromptParserTestCase(unittest.TestCase):
with self.assertRaises(pyparsing.ParseException): with self.assertRaises(pyparsing.ParseException):
parse_prompt('("((a badly (formed +test ").blend(1.0)') parse_prompt('("((a badly (formed +test ").blend(1.0)')
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
parse_prompt("hamburger ((bun))"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
parse_prompt("hamburger (bun)"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger kaiser roll', 1)])]),
parse_prompt("hamburger (kaiser roll)"))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger kaiser roll', 1)])]),
parse_prompt("hamburger ((kaiser roll))"))
def test_blend(self): def test_blend(self):
self.assertEqual(Conjunction( self.assertEqual(Conjunction(
@ -284,6 +293,7 @@ class PromptParserTestCase(unittest.TestCase):
])]), ])]),
parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog-, shape_freedom=0.5)")) parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog-, shape_freedom=0.5)"))
def test_cross_attention_control_options(self): def test_cross_attention_control_options(self):
self.assertEqual(Conjunction([ self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1), FlattenedPrompt([Fragment('a', 1),
@ -426,9 +436,9 @@ class PromptParserTestCase(unittest.TestCase):
# todo handle this # todo handle this
#self.assertEqual(make_basic_conjunction(['a badly formed +test prompt']), #self.assertEqual(make_basic_conjunction(['a badly formed +test prompt']),
# parse_prompt('a badly formed +test prompt')) # parse_prompt('a badly formed +test prompt'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]), CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])])
parse_prompt('a forest landscape "in winter".swap()')) self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")'))
pass pass