wip re-writing parts of prompt parser

This commit is contained in:
Damian at mba 2022-10-20 15:56:46 +02:00
parent c9d27634b4
commit da223dfe81
2 changed files with 52 additions and 31 deletions

View File

@ -1,3 +1,5 @@
import string
import pyparsing
import pyparsing as pp
from pyparsing import original_text_for
@ -200,8 +202,8 @@ class PromptParser():
elif type(node) is Fragment:
results += [Fragment(node.text, node.weight*weight_scale)]
elif type(node) is CrossAttentionControlSubstitute:
original = flatten_internal(node.original, weight_scale, [], ' CAo ')
edited = flatten_internal(node.edited, weight_scale, [], ' CAe ')
original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ')
edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ')
results += [CrossAttentionControlSubstitute(original, edited)]
elif type(node) is Blend:
flattened_subprompts = []
@ -236,24 +238,46 @@ class PromptParser():
lparen = pp.Literal("(").suppress()
rparen = pp.Literal(")").suppress()
quotes = pp.Literal('"').suppress()
# accepts int or float notation, always maps to float
number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))
SPACE_CHARS = ' \t\n'
prompt_part = pp.Forward()
word = pp.Word(pp.printables).set_parse_action(lambda x: Fragment(' '.join([s for s in x])))
word.set_name("word")
word.set_debug(False)
SPACE_CHARS = string.whitespace
def make_fragment(x):
#print("### making fragment for", x)
if type(x) is str:
return Fragment(x)
elif type(x) is pp.ParseResults or type(x) is list:
#print(f'converting {x} to Fragment')
return Fragment(' '.join([s for s in x]))
else:
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
def parse_fragment_str(x):
return make_fragment(x)
quoted_fragment = pp.QuotedString(quote_char='"', esc_char='\\')
quoted_fragment.set_parse_action(parse_fragment_str).set_name('quoted_fragment')
unquoted_fragment = pp.Combine(pp.OneOrMore(
pp.Literal('\\"').set_debug(False) |
pp.Literal('\\').set_debug(False) |
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"')
))
unquoted_fragment.set_parse_action(parse_fragment_str).set_name('unquoted_fragment')
parenthesized_fragment = \
(lparen + quoted_fragment.set_debug(True) + rparen).set_name('quoted_paren_internal') | \
(lparen + rparen).set_parse_action(lambda x: make_fragment('')) | \
(lparen + pp.Combine(pp.OneOrMore(
pp.Literal('\\)').set_debug(False) |
pp.Literal('\\').set_debug(False) |
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\)') |
pp.Word(string.whitespace)
)) + rparen).set_parse_action(parse_fragment_str).set_name('unquoted_paren_internal').set_debug(True)
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(True)
# attention control of the form +(phrase) / -(phrase) / <weight>(phrase)
# phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
attention = pp.Forward()
@ -303,41 +327,35 @@ class PromptParser():
pp.Literal('""').suppress() |
(lparen + pp.Literal('""').suppress() + rparen)
).set_parse_action(lambda x: Fragment(""))
empty_string.set_name('empty_string')
original_words = (
(lparen + pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) |
(pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('term2').set_debug(False) |
(lparen + (pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)) + rparen).set_name('term3').set_debug(False)
).set_name('original_words')
edited_words = (
(pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('termA').set_debug(False) |
pp.Literal('""').suppress().set_parse_action(lambda x: Fragment("")) |
(pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)).set_name('termB').set_debug(True)
).set_name('edited_words')
cross_attention_substitute = (empty_string | original_words) + \
pp.Literal(".swap").suppress() + \
(empty_string | (lparen + edited_words + rparen)
)
cross_attention_substitute.set_name('cross_attention_substitute')
original_fragment = empty_string | quoted_fragment | parenthesized_fragment | unquoted_fragment
edited_fragment = parenthesized_fragment
cross_attention_substitute = original_fragment + pp.Literal(".swap").suppress() + edited_fragment
cross_attention_substitute.set_name('cross_attention_substitute').set_debug(True)
def make_cross_attention_substitute(x):
#print("making cacs for", x)
print("making cacs for", x)
cacs = CrossAttentionControlSubstitute(x[0], x[1])
#print("made", cacs)
print("made", cacs)
return cacs
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
# simple fragments of text
prompt_part << (cross_attention_substitute
| attention
| word
)
prompt_part = (
cross_attention_substitute
| attention
| quoted_fragment
| unquoted_fragment
)
prompt_part.set_debug(False)
prompt_part.set_name("prompt_part")
empty = ((lparen + rparen) | (quotes + quotes)).suppress()
# root prompt definition
prompt = pp.Group(pp.OneOrMore(prompt_part))\
prompt = (pp.Group(pp.OneOrMore(prompt_part) | empty) + pp.StringEnd()) \
.set_parse_action(lambda x: Prompt(x[0]))
# weighted blend of prompts

View File

@ -201,6 +201,9 @@ class PromptParserTestCase(unittest.TestCase):
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)'))
def test_single(self):
print(parse_prompt('fire (trees and houses).swap("flames")'))
if __name__ == '__main__':
unittest.main()