prompt parsing is now much more robust

This commit is contained in:
Damian at mba
2022-10-20 21:05:36 +02:00
parent 79b4afeae7
commit 3f13dd3ae8
2 changed files with 220 additions and 104 deletions

View File

@ -9,8 +9,8 @@ class Prompt():
def __init__(self, parts: list):
for c in parts:
if type(c) is not Attention and not issubclass(type(c), BaseFragment):
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {BaseFragment.__subclasses__()} are allowed")
if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} {c}, only {BaseFragment.__subclasses__()} are allowed")
self.children = parts
def __repr__(self):
return f"Prompt:{self.children}"
@ -48,6 +48,9 @@ class BaseFragment:
class Fragment(BaseFragment):
def __init__(self, text: str, weight: float=1):
assert(type(text) is str)
if '\\"' in text or '\\(' in text or '\\)' in text:
#print("Fragment converting escaped \( \) \\\" into ( ) \"")
text = text.replace('\\(', '(').replace('\\)', ')').replace('\\"', '"')
self.text = text
self.weight = float(weight)
@ -152,8 +155,10 @@ class PromptParser():
def parse(self, prompt: str) -> Conjunction:
'''
This parser is *very* forgiving. If it cannot parse syntax, it will return strings as-is to be passed on to the
diffusion.
:param prompt: The prompt string to parse
:return: a tuple
:return: a Conjunction representing the parsed results.
'''
#print(f"!!parsing '{prompt}'")
@ -169,7 +174,7 @@ class PromptParser():
def flatten(self, root: Conjunction):
print("flattening", root)
#print("flattening", root)
def fuse_fragments(items):
# print("fusing fragments in ", items)
@ -196,13 +201,13 @@ class PromptParser():
#print(prefix + "flattening", node, "...")
if type(node) is pp.ParseResults:
for x in node:
results = flatten_internal(x, weight_scale, results, prefix+'pr')
results = flatten_internal(x, weight_scale, results, prefix+' pr ')
#print(prefix, " ParseResults expanded, results is now", results)
elif type(node) is Attention:
# if node.weight < 1:
# todo: inject a blend when flattening attention with weight <1"
for c in node.children:
results = flatten_internal(c, weight_scale * node.weight, results, prefix + ' ')
for index,c in enumerate(node.children):
results = flatten_internal(c, weight_scale * node.weight, results, prefix + f" att{index} ")
elif type(node) is Fragment:
results += [Fragment(node.text, node.weight*weight_scale)]
elif type(node) is CrossAttentionControlSubstitute:
@ -225,7 +230,7 @@ class PromptParser():
#print(prefix + "after flattening Prompt, results is", results)
else:
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
print(prefix + "-> after flattening", type(node).__name__, "results is", results)
#print(prefix + "-> after flattening", type(node).__name__, "results is", results)
return results
@ -246,6 +251,7 @@ class PromptParser():
# accepts int or float notation, always maps to float
number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))
SPACE_CHARS = string.whitespace
greedy_word = pp.Word(pp.printables, exclude_chars=string.whitespace).set_name('greedy_word')
attention = pp.Forward()
@ -254,7 +260,7 @@ class PromptParser():
if type(x) is str:
return Fragment(x)
elif type(x) is pp.ParseResults or type(x) is list:
#print(f'converting {x} to Fragment')
#print(f'converting {type(x).__name__} to Fragment')
return Fragment(' '.join([s for s in x]))
else:
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
@ -264,52 +270,72 @@ class PromptParser():
parenthesized_fragment = pp.Forward()
def parse_fragment_str(x):
print("parsing", x)
#print("parsing fragment string", x)
if len(x[0].strip()) == 0:
return Fragment('')
fragment_parser = pp.Group(pp.OneOrMore(attention | pp.Word(pp.printables, exclude_chars=string.whitespace).set_parse_action(make_fragment)))
fragment_parser = pp.Group(pp.OneOrMore(attention | (greedy_word.set_parse_action(make_fragment))))
fragment_parser.set_name('word_or_attention')
result = fragment_parser.parse_string(x[0])
#result = (pp.OneOrMore(attention | unquoted_fragment) + pp.StringEnd()).parse_string(x[0])
print("parsed to", result)
#print("parsed to", result)
return result
quoted_fragment << pp.QuotedString(quote_char='"', esc_char='\\')
quoted_fragment.set_parse_action(make_fragment).set_name('quoted_fragment')
quoted_fragment.set_parse_action(parse_fragment_str).set_name('quoted_fragment')
self_unescaping_escaped_quote = pp.Literal('\\"').set_parse_action(lambda x: '"')
self_unescaping_escaped_lparen = pp.Literal('\\(').set_parse_action(lambda x: '(')
self_unescaping_escaped_rparen = pp.Literal('\\)').set_parse_action(lambda x: ')')
unquoted_fragment << pp.Combine(pp.OneOrMore(
pp.Literal('\\"').set_debug(False) |
pp.Literal('\\').set_debug(False) |
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"')
self_unescaping_escaped_rparen | self_unescaping_escaped_lparen | self_unescaping_escaped_quote |
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()')
))
unquoted_fragment.set_parse_action(make_fragment).set_name('unquoted_fragment')
parenthesized_fragment << pp.Or([
(lparen + quoted_fragment.set_parse_action(parse_fragment_str).set_debug(True) + rparen).set_name('-quoted_paren_internal').set_debug(True),
(lparen + rparen).set_parse_action(lambda x: make_fragment('')).set_name('-()').set_debug(True),
parenthesized_fragment << pp.MatchFirst([
(lparen + quoted_fragment.copy().set_parse_action(parse_fragment_str).set_debug(False) + rparen).set_name('-quoted_paren_internal').set_debug(False),
(lparen + rparen).set_parse_action(lambda x: make_fragment('')).set_name('-()').set_debug(False),
(lparen + pp.Combine(pp.OneOrMore(
pp.Literal('\\)').set_debug(False) |
pp.Literal('\\').set_debug(False) |
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\)') |
pp.Literal('\\"').set_debug(False).set_parse_action(lambda x: '"') |
pp.Literal('\\(').set_debug(False).set_parse_action(lambda x: '(') |
pp.Literal('\\)').set_debug(False).set_parse_action(lambda x: ')') |
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') |
pp.Word(string.whitespace)
)).set_name('--combined').set_parse_action(parse_fragment_str).set_debug(True) + rparen)]).set_name('-unquoted_paren_internal').set_debug(True)
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(True)
)).set_name('--combined').set_parse_action(parse_fragment_str).set_debug(False) + rparen)]).set_name('-unquoted_paren_internal').set_debug(False)
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False)
debug_attention = False
# attention control of the form +(phrase) / -(phrase) / <weight>(phrase)
# phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
attention_head = (number | pp.Word('+') | pp.Word('-'))\
.set_name("attention_head")\
.set_debug(False)
fragment_inside_attention = pp.CharsNotIn(SPACE_CHARS+'()')\
.set_parse_action(make_fragment)\
.set_name("fragment_inside_attention")\
.set_debug(False)
word_inside_attention = pp.Combine(pp.OneOrMore(
pp.Literal('\\)') | pp.Literal('\\(') | pp.Literal('\\"') |
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\()"')
)).set_name('word_inside_attention')
attention_with_parens = pp.Forward()
attention_with_parens_body = pp.nested_expr(content=pp.delimited_list((attention_with_parens | fragment_inside_attention), delim=SPACE_CHARS))
attention_with_parens_delimited_list = pp.delimited_list(pp.Or([
quoted_fragment.copy().set_debug(debug_attention),
attention.copy().set_debug(debug_attention),
word_inside_attention.set_debug(debug_attention)]).set_name('delim_inner').set_debug(debug_attention),
delim=string.whitespace)
# have to disable ignore_expr here to prevent pyparsing from stripping off quote marks
attention_with_parens_body = pp.nested_expr(content=attention_with_parens_delimited_list,
ignore_expr=None#((pp.Literal("\\(") | pp.Literal('\\)')))
)
attention_with_parens_body.set_debug(debug_attention)
attention_with_parens << (attention_head + attention_with_parens_body)
attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention)
attention_without_parens = (pp.Word('+') | pp.Word('-')) + (quoted_fragment | word_inside_attention)
attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention)
attention << (attention_with_parens | attention_without_parens)
def make_attention(x):
# print("making Attention from parsing with args", x0, x1)
#print("making Attention from", x)
weight = 1
# number(str)
if type(x[0]) is float or type(x[0]) is int:
@ -318,26 +344,17 @@ class PromptParser():
elif type(x[0]) is str:
base = self.attention_plus_base if x[0][0] == '+' else self.attention_minus_base
weight = pow(base, len(x[0]))
# print("Making attention with children of type", [str(type(x)) for x in x1])
return Attention(weight=weight, children=x[1])
if type(x[1]) is list or type(x[1]) is pp.ParseResults:
return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in x[1]])
elif type(x[1]) is str:
return Attention(weight=weight, children=[Fragment(x[1])])
elif type(x[1]) is Fragment:
return Attention(weight=weight, children=[x[1]])
raise PromptParser.ParsingException(f"Don't know how to make attention with children {x[1]}")
attention_with_parens.set_parse_action(make_attention)\
.set_name("attention_with_parens")\
.set_debug(False)
# attention control of the form ++word --word (no parens)
attention_without_parens = (
(pp.Word('+') | pp.Word('-')) +
pp.CharsNotIn(SPACE_CHARS+'()').set_parse_action(lambda x: [[make_fragment(x)]])
)\
.set_name("attention_without_parens")\
.set_debug(False)
attention_with_parens.set_parse_action(make_attention)
attention_without_parens.set_parse_action(make_attention)
attention << (attention_with_parens | attention_without_parens)\
.set_name("attention")\
.set_debug(False)
# cross-attention control
empty_string = ((lparen + rparen) |
pp.Literal('""').suppress() |
@ -345,26 +362,38 @@ class PromptParser():
).set_parse_action(lambda x: Fragment(""))
empty_string.set_name('empty_string')
original_fragment = empty_string | quoted_fragment | parenthesized_fragment | unquoted_fragment
# cross attention control
debug_cross_attention_control = False
original_fragment = pp.Or([empty_string.set_debug(debug_cross_attention_control),
quoted_fragment.set_debug(debug_cross_attention_control),
parenthesized_fragment.set_debug(debug_cross_attention_control),
unquoted_fragment.set_debug(debug_cross_attention_control)])
edited_fragment = parenthesized_fragment
cross_attention_substitute = original_fragment + pp.Literal(".swap").suppress() + edited_fragment
cross_attention_substitute.set_name('cross_attention_substitute').set_debug(True)
original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control)
cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control)
def make_cross_attention_substitute(x):
print("making cacs for", x)
#print("making cacs for", x)
cacs = CrossAttentionControlSubstitute(x[0], x[1])
print("made", cacs)
#print("made", cacs)
return cacs
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
# simple fragments of text
prompt_part = (
cross_attention_substitute
| attention
| quoted_fragment
| unquoted_fragment
)
# use Or to match the longest
prompt_part = pp.Or([
cross_attention_substitute,
attention,
quoted_fragment,
unquoted_fragment,
lparen + unquoted_fragment + rparen # matches case where user has +(term) and just deletes the +
])
prompt_part.set_debug(False)
prompt_part.set_name("prompt_part")
@ -373,8 +402,10 @@ class PromptParser():
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')
# root prompt definition
prompt = (pp.Group(pp.OneOrMore(prompt_part) | empty) + pp.StringEnd()) \
.set_parse_action(lambda x: Prompt(x[0]))
prompt = ((pp.OneOrMore(prompt_part) | empty) + pp.StringEnd()) \
.set_parse_action(lambda x: Prompt(x))
# weighted blend of prompts
# ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or
@ -418,7 +449,7 @@ class PromptParser():
return Conjunction(parts, weights)
conjunction_with_parens_and_quotes.set_parse_action(make_conjunction)
implicit_conjunction = pp.OneOrMore(blend | prompt)
implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction')
implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
conjunction = conjunction_with_parens_and_quotes | implicit_conjunction