From da223dfe819efe96b3be99c158c08c96c619337a Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Thu, 20 Oct 2022 15:56:46 +0200 Subject: [PATCH] wip re-writing parts of prompt parser --- ldm/invoke/prompt_parser.py | 80 +++++++++++++++++++++++-------------- tests/test_prompt_parser.py | 3 ++ 2 files changed, 52 insertions(+), 31 deletions(-) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index c13175a488..abd9ce726c 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -1,3 +1,5 @@ +import string + import pyparsing import pyparsing as pp from pyparsing import original_text_for @@ -200,8 +202,8 @@ class PromptParser(): elif type(node) is Fragment: results += [Fragment(node.text, node.weight*weight_scale)] elif type(node) is CrossAttentionControlSubstitute: - original = flatten_internal(node.original, weight_scale, [], ' CAo ') - edited = flatten_internal(node.edited, weight_scale, [], ' CAe ') + original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ') + edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ') results += [CrossAttentionControlSubstitute(original, edited)] elif type(node) is Blend: flattened_subprompts = [] @@ -236,24 +238,46 @@ class PromptParser(): lparen = pp.Literal("(").suppress() rparen = pp.Literal(")").suppress() + quotes = pp.Literal('"').suppress() + # accepts int or float notation, always maps to float number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float)) - SPACE_CHARS = ' \t\n' - - prompt_part = pp.Forward() - word = pp.Word(pp.printables).set_parse_action(lambda x: Fragment(' '.join([s for s in x]))) - word.set_name("word") - word.set_debug(False) + SPACE_CHARS = string.whitespace def make_fragment(x): #print("### making fragment for", x) if type(x) is str: return Fragment(x) elif type(x) is pp.ParseResults or type(x) is list: + #print(f'converting {x} to Fragment') return Fragment(' '.join([s for s in x])) else: raise PromptParser.ParsingException("Cannot make fragment from " + str(x)) + def parse_fragment_str(x): + return make_fragment(x) + + quoted_fragment = pp.QuotedString(quote_char='"', esc_char='\\') + quoted_fragment.set_parse_action(parse_fragment_str).set_name('quoted_fragment') + + unquoted_fragment = pp.Combine(pp.OneOrMore( + pp.Literal('\\"').set_debug(False) | + pp.Literal('\\').set_debug(False) | + pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"') + )) + unquoted_fragment.set_parse_action(parse_fragment_str).set_name('unquoted_fragment') + + parenthesized_fragment = \ + (lparen + quoted_fragment.set_debug(True) + rparen).set_name('quoted_paren_internal') | \ + (lparen + rparen).set_parse_action(lambda x: make_fragment('')) | \ + (lparen + pp.Combine(pp.OneOrMore( + pp.Literal('\\)').set_debug(False) | + pp.Literal('\\').set_debug(False) | + pp.Word(pp.printables, exclude_chars=string.whitespace + '\\)') | + pp.Word(string.whitespace) + )) + rparen).set_parse_action(parse_fragment_str).set_name('unquoted_paren_internal').set_debug(True) + parenthesized_fragment.set_name('parenthesized_fragment').set_debug(True) + # attention control of the form +(phrase) / -(phrase) / (phrase) # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight attention = pp.Forward() @@ -303,41 +327,35 @@ class PromptParser(): pp.Literal('""').suppress() | (lparen + pp.Literal('""').suppress() + rparen) ).set_parse_action(lambda x: Fragment("")) + empty_string.set_name('empty_string') - original_words = ( - (lparen + pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) | - (pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('term2').set_debug(False) | - (lparen + (pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)) + rparen).set_name('term3').set_debug(False) - ).set_name('original_words') - edited_words = ( - (pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('termA').set_debug(False) | - pp.Literal('""').suppress().set_parse_action(lambda x: Fragment("")) | - (pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)).set_name('termB').set_debug(True) - ).set_name('edited_words') - cross_attention_substitute = (empty_string | original_words) + \ - pp.Literal(".swap").suppress() + \ - (empty_string | (lparen + edited_words + rparen) - ) - cross_attention_substitute.set_name('cross_attention_substitute') + original_fragment = empty_string | quoted_fragment | parenthesized_fragment | unquoted_fragment + edited_fragment = parenthesized_fragment + cross_attention_substitute = original_fragment + pp.Literal(".swap").suppress() + edited_fragment + + cross_attention_substitute.set_name('cross_attention_substitute').set_debug(True) def make_cross_attention_substitute(x): - #print("making cacs for", x) + print("making cacs for", x) cacs = CrossAttentionControlSubstitute(x[0], x[1]) - #print("made", cacs) + print("made", cacs) return cacs - cross_attention_substitute.set_parse_action(make_cross_attention_substitute) # simple fragments of text - prompt_part << (cross_attention_substitute - | attention - | word - ) + prompt_part = ( + cross_attention_substitute + | attention + | quoted_fragment + | unquoted_fragment + ) prompt_part.set_debug(False) prompt_part.set_name("prompt_part") + empty = ((lparen + rparen) | (quotes + quotes)).suppress() + # root prompt definition - prompt = pp.Group(pp.OneOrMore(prompt_part))\ + prompt = (pp.Group(pp.OneOrMore(prompt_part) | empty) + pp.StringEnd()) \ .set_parse_action(lambda x: Prompt(x[0])) # weighted blend of prompts diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 99f4db33a1..0aa0cfd6ae 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -201,6 +201,9 @@ class PromptParserTestCase(unittest.TestCase): Fragment(',', 1), Fragment('fire', 2.0)])]) self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)')) + def test_single(self): + print(parse_prompt('fire (trees and houses).swap("flames")')) + if __name__ == '__main__': unittest.main()