parser working with basic escapes

This commit is contained in:
Damian at mba 2022-10-20 16:56:34 +02:00
parent da223dfe81
commit 79b4afeae7
2 changed files with 50 additions and 19 deletions

View File

@ -169,12 +169,16 @@ class PromptParser():
def flatten(self, root: Conjunction):
print("flattening", root)
def fuse_fragments(items):
# print("fusing fragments in ", items)
result = []
for x in items:
if issubclass(type(x), CrossAttentionControlledFragment):
result.append(x)
if type(x) is CrossAttentionControlSubstitute:
original_fused = fuse_fragments(x.original)
edited_fused = fuse_fragments(x.edited)
result.append(CrossAttentionControlSubstitute(original_fused, edited_fused))
else:
last_weight = result[-1].weight \
if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \
@ -221,10 +225,9 @@ class PromptParser():
#print(prefix + "after flattening Prompt, results is", results)
else:
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
print(prefix + "-> after flattening", type(node), "results is", results)
print(prefix + "-> after flattening", type(node).__name__, "results is", results)
return results
#print("flattening", root)
flattened_parts = []
for part in root.prompts:
@ -244,6 +247,8 @@ class PromptParser():
number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))
SPACE_CHARS = string.whitespace
attention = pp.Forward()
def make_fragment(x):
#print("### making fragment for", x)
if type(x) is str:
@ -254,33 +259,44 @@ class PromptParser():
else:
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
unquoted_fragment = pp.Forward()
quoted_fragment = pp.Forward()
parenthesized_fragment = pp.Forward()
def parse_fragment_str(x):
return make_fragment(x)
print("parsing", x)
if len(x[0].strip()) == 0:
return Fragment('')
fragment_parser = pp.Group(pp.OneOrMore(attention | pp.Word(pp.printables, exclude_chars=string.whitespace).set_parse_action(make_fragment)))
fragment_parser.set_name('word_or_attention')
result = fragment_parser.parse_string(x[0])
#result = (pp.OneOrMore(attention | unquoted_fragment) + pp.StringEnd()).parse_string(x[0])
print("parsed to", result)
return result
quoted_fragment = pp.QuotedString(quote_char='"', esc_char='\\')
quoted_fragment.set_parse_action(parse_fragment_str).set_name('quoted_fragment')
quoted_fragment << pp.QuotedString(quote_char='"', esc_char='\\')
quoted_fragment.set_parse_action(make_fragment).set_name('quoted_fragment')
unquoted_fragment = pp.Combine(pp.OneOrMore(
unquoted_fragment << pp.Combine(pp.OneOrMore(
pp.Literal('\\"').set_debug(False) |
pp.Literal('\\').set_debug(False) |
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"')
))
unquoted_fragment.set_parse_action(parse_fragment_str).set_name('unquoted_fragment')
unquoted_fragment.set_parse_action(make_fragment).set_name('unquoted_fragment')
parenthesized_fragment = \
(lparen + quoted_fragment.set_debug(True) + rparen).set_name('quoted_paren_internal') | \
(lparen + rparen).set_parse_action(lambda x: make_fragment('')) | \
parenthesized_fragment << pp.Or([
(lparen + quoted_fragment.set_parse_action(parse_fragment_str).set_debug(True) + rparen).set_name('-quoted_paren_internal').set_debug(True),
(lparen + rparen).set_parse_action(lambda x: make_fragment('')).set_name('-()').set_debug(True),
(lparen + pp.Combine(pp.OneOrMore(
pp.Literal('\\)').set_debug(False) |
pp.Literal('\\').set_debug(False) |
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\)') |
pp.Word(string.whitespace)
)) + rparen).set_parse_action(parse_fragment_str).set_name('unquoted_paren_internal').set_debug(True)
)).set_name('--combined').set_parse_action(parse_fragment_str).set_debug(True) + rparen)]).set_name('-unquoted_paren_internal').set_debug(True)
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(True)
# attention control of the form +(phrase) / -(phrase) / <weight>(phrase)
# phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
attention = pp.Forward()
attention_head = (number | pp.Word('+') | pp.Word('-'))\
.set_name("attention_head")\
.set_debug(False)
@ -352,7 +368,9 @@ class PromptParser():
prompt_part.set_debug(False)
prompt_part.set_name("prompt_part")
empty = ((lparen + rparen) | (quotes + quotes)).suppress()
empty = (
(lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) |
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')
# root prompt definition
prompt = (pp.Group(pp.OneOrMore(prompt_part) | empty) + pp.StringEnd()) \

View File

@ -174,7 +174,7 @@ class PromptParserTestCase(unittest.TestCase):
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
parse_prompt('a forest landscape "".swap("in winter")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment(' ',1)], [Fragment('in winter',1)])])]),
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
parse_prompt('a forest landscape " ".swap("in winter")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
@ -184,7 +184,7 @@ class PromptParserTestCase(unittest.TestCase):
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
parse_prompt('a forest landscape "in winter".swap()'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment(' ',1)])])]),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
parse_prompt('a forest landscape "in winter".swap(" ")'))
def test_cross_attention_control_with_attention(self):
@ -201,8 +201,21 @@ class PromptParserTestCase(unittest.TestCase):
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)'))
def test_single(self):
print(parse_prompt('fire (trees and houses).swap("flames")'))
def make_basic_conjunction(self, strings: list[str]):
fragments = [Fragment(x) for x in strings]
return Conjunction([FlattenedPrompt(fragments)])
def make_weighted_conjunction(self, weighted_strings: list[tuple[str,float]]):
fragments = [Fragment(x, w) for x,w in weighted_strings]
return Conjunction([FlattenedPrompt(fragments)])
def test_escaping(self):
self.assertEqual(self.make_basic_conjunction(['mountain \(man\)']),parse_prompt('mountain \(man\)'))
self.assertEqual(self.make_basic_conjunction(['mountain (\(man)\)']),parse_prompt('mountain (\(man)\)'))
self.assertEqual(self.make_basic_conjunction(['mountain (\(man\))']),parse_prompt('mountain (\(man\))'))
#self.assertEqual(self.make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain +(\(man\))'))
if __name__ == '__main__':