Rebuilt prompt parsing logic

Complete rewrite of the prompt-parsing logic to make it more readable and
logical, and therefore hopefully also easier to debug, maintain, and
extend.

In the process it has also become more robust to badly-formed prompts.
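
For reference, the forms the rebuilt parser accepts include the following
(taken from the updated unit tests; parse_prompt here is the test-suite
helper, not a parser API):

    parse_prompt("fire (flames)0.5")                  # attention weighting
    parse_prompt("flames.attend(0.5)")                # explicit .attend operator
    parse_prompt("a cat.swap(dog) eating a hotdog")   # cross-attention substitution
    parse_prompt('("fire", "fire flames").blend(0.7, 0.3)')                # weighted blend
    parse_prompt('("mountain man", "a person with a hat").and(0.5, 0.5)')  # conjunction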

Squashed commit of the following:

commit 8fcfa88a16e1390d41717e940d72aed64712171c
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Sun Oct 30 17:05:57 2022 +0100

    further cleanup

commit 1a1fd78bcfeb49d072e3e6d5808aa8df15441629
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Sun Oct 30 16:07:57 2022 +0100

    cleanup and document

commit 099c9659fa8b8135876f9a5a50fe80b20bc0635c
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Sun Oct 30 15:54:58 2022 +0100

    works fully

commit 5e6887ea8c25a1e21438ff6defb381fd027d25fd
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Sun Oct 30 15:24:31 2022 +0100

    further...

commit 492fda120844d9bc1ad4ec7dd408a3374762d0ff
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Sun Oct 30 14:08:57 2022 +0100

    getting there...

commit c6aab05a8450cc3c95c8691daf38fdc64c74f52d
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Fri Oct 28 14:29:03 2022 +0200

    wip doesn't compile

commit 5e533f731cfd20cd435330eeb0012e5689e87e81
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Fri Oct 28 13:21:43 2022 +0200

    working with CrossAttentionControl but no Attention support yet

commit 9678348773431e500e110e8aede99086bb7b5955
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Fri Oct 28 13:04:52 2022 +0200

    wip rebuilding prompt parser

Authored by Damian at mba on 2022-11-01 10:08:42 +01:00; committed by Lincoln Stein
parent 349cc25433
commit e554c2607f
2 changed files with 314 additions and 293 deletions


@@ -28,7 +28,7 @@ class Prompt():
     def __init__(self, parts: list):
         for c in parts:
             if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
-                raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} {c}, only {BaseFragment.__subclasses__()} are allowed")
+                raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} ({c}), only {[c.__name__ for c in BaseFragment.__subclasses__()]} are allowed")
         self.children = parts
     def __repr__(self):
         return f"Prompt:{self.children}"
@@ -102,12 +102,18 @@ class Attention():
     Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object.
     """
     def __init__(self, weight: float, children: list):
+        if type(weight) is not float:
+            raise PromptParser.ParsingException(
+                f"Attention weight must be float (got {type(weight).__name__} {weight})")
         self.weight = weight
+        if type(children) is not list:
+            raise PromptParser.ParsingException(f"cannot make Attention with non-list of children (got {type(children)})")
+        assert(type(children) is list)
         self.children = children
         #print(f"A: requested attention '{children}' to {weight}")
     def __repr__(self):
-        return f"Attention:'{self.children}' @ {self.weight}"
+        return f"Attention:{self.children} * {self.weight}"
     def __eq__(self, other):
         return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment
@@ -136,9 +142,9 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
             Fragment('sitting on a car')
         ])
     """
-    def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list], options: dict=None):
+    def __init__(self, original: list, edited: list, options: dict=None):
         self.original = original
-        self.edited = edited
+        self.edited = edited if len(edited)>0 else [Fragment('')]

         default_options = {
             's_start': 0.0,
@@ -190,12 +196,12 @@ class Conjunction():
     """
     def __init__(self, prompts: list, weights: list = None):
         # force everything to be a Prompt
-        #print("making conjunction with", parts)
+        #print("making conjunction with", prompts, "types", [type(p).__name__ for p in prompts])
         self.prompts = [x if (type(x) is Prompt
                               or type(x) is Blend
                               or type(x) is FlattenedPrompt)
                         else Prompt(x) for x in prompts]
-        self.weights = [1.0]*len(self.prompts) if weights is None else list(weights)
+        self.weights = [1.0]*len(self.prompts) if (weights is None or len(weights)==0) else list(weights)
         if len(self.weights) != len(self.prompts):
             raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
         self.type = 'AND'
@@ -216,6 +222,7 @@ class Blend():
     """
     def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
         #print("making Blend with prompts", prompts, "and weights", weights)
+        weights = [1.0]*len(prompts) if (weights is None or len(weights)==0) else list(weights)
         if len(prompts) != len(weights):
             raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
         for p in prompts:
@@ -244,6 +251,10 @@ class PromptParser():
     class ParsingException(Exception):
         pass

+    class UnrecognizedOperatorException(Exception):
+        def __init__(self, operator:str):
+            super().__init__("Unrecognized operator: " + operator)
+
     def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
         self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base)
@@ -279,7 +290,7 @@ class PromptParser():
         return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)

-    def flatten(self, root: Conjunction) -> Conjunction:
+    def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
         """
         Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
         producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
@@ -289,8 +300,6 @@ class PromptParser():
         :return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
         """

-        #print("flattening", root)
-
         def fuse_fragments(items):
             # print("fusing fragments in ", items)
             result = []
@@ -313,8 +322,8 @@ class PromptParser():
             return result

         def flatten_internal(node, weight_scale, results, prefix):
-            #print(prefix + "flattening", node, "...")
-            if type(node) is pp.ParseResults:
+            verbose and print(prefix + "flattening", node, "...")
+            if type(node) is pp.ParseResults or type(node) is list:
                 for x in node:
                     results = flatten_internal(x, weight_scale, results, prefix+' pr ')
                 #print(prefix, " ParseResults expanded, results is now", results)
@@ -345,67 +354,59 @@ class PromptParser():
                 #print(prefix + "after flattening Prompt, results is", results)
             else:
                 raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
-            #print(prefix + "-> after flattening", type(node).__name__, "results is", results)
+            verbose and print(prefix + "-> after flattening", type(node).__name__, "results is", results)
             return results

+        verbose and print("flattening", root)
         flattened_parts = []
         for part in root.prompts:
             flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
-        #print("flattened to", flattened_parts)
+        verbose and print("flattened to", flattened_parts)
         weights = root.weights
         return Conjunction(flattened_parts, weights)

 def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
-    lparen = pp.Literal("(").suppress()
-    rparen = pp.Literal(")").suppress()
-    quotes = pp.Literal('"').suppress()
-    comma = pp.Literal(",").suppress()
-    # accepts int or float notation, always maps to float
-    number = pp.pyparsing_common.real | \
-             pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
-    attention = pp.Forward()
-    quoted_fragment = pp.Forward()
-    parenthesized_fragment = pp.Forward()
-    cross_attention_substitute = pp.Forward()
-
-    def make_text_fragment(x):
-        #print("### making fragment for", x)
-        if type(x[0]) is Fragment:
-            assert(False)
-        if type(x) is str:
-            return Fragment(x)
-        elif type(x) is pp.ParseResults or type(x) is list:
-            #print(f'converting {type(x).__name__} to Fragment')
-            return Fragment(' '.join([s for s in x]))
-        else:
-            raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
-
-    def build_escaped_word_parser_charbychar(escaped_chars_to_ignore: str):
-        escapes = []
-        for c in escaped_chars_to_ignore:
-            escapes.append(pp.Literal('\\'+c))
-        return pp.Combine(pp.OneOrMore(
-            pp.MatchFirst(escapes + [pp.CharsNotIn(
-                string.whitespace + escaped_chars_to_ignore,
-                exact=1
-            )])
-        ))
-
-    def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
+    def make_operator_object(x):
+        #print('making operator for', x)
+        target = x[0]
+        operator = x[1]
+        arguments = x[2]
+        if operator == '.attend':
+            weight_raw = arguments[0]
+            weight = 1.0
+            if type(weight_raw) is float or type(weight_raw) is int:
+                weight = weight_raw
+            elif type(weight_raw) is str:
+                base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
+                weight = pow(base, len(weight_raw))
+            return Attention(weight=weight, children=[x for x in x[0]])
+        elif operator == '.swap':
+            return CrossAttentionControlSubstitute(target, arguments, x.as_dict())
+        elif operator == '.blend':
+            prompts = [Prompt(p) for p in x[0]]
+            weights_raw = x[2]
+            normalize_weights = True
+            if len(weights_raw) > 0 and weights_raw[-1][0] == 'no_normalize':
+                normalize_weights = False
+                weights_raw = weights_raw[:-1]
+            weights = [float(w[0]) for w in weights_raw]
+            return Blend(prompts=prompts, weights=weights, normalize_weights=normalize_weights)
+        elif operator == '.and' or operator == '.add':
+            prompts = [Prompt(p) for p in x[0]]
+            weights = [float(w[0]) for w in x[2]]
+            return Conjunction(prompts=prompts, weights=weights)
+
+        raise PromptParser.UnrecognizedOperatorException(operator)
+
+    def parse_fragment_str(x, expression: pp.ParseExpression, in_quotes: bool = False, in_parens: bool = False):
         #print(f"parsing fragment string for {x}")
         fragment_string = x[0]
-        #print(f"ppparsing fragment string \"{fragment_string}\"")
         if len(fragment_string.strip()) == 0:
             return Fragment('')
@@ -413,234 +414,198 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
         # escape unescaped quotes
         fragment_string = fragment_string.replace('"', '\\"')

-        #fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment))))
         try:
-            result = pp.Group(pp.MatchFirst([
-                pp.OneOrMore(quoted_fragment | attention | unquoted_word).set_name('pf_str_qfuq'),
-                pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd()
-            ])).set_name('blend-result').set_debug(False).parse_string(fragment_string)
+            result = (expression + pp.StringEnd()).parse_string(fragment_string)
             #print("parsed to", result)
             return result
         except pp.ParseException as e:
             #print("parse_fragment_str couldn't parse prompt string:", e)
             raise

+    # meaningful symbols
+    lparen = pp.Literal("(").suppress()
+    rparen = pp.Literal(")").suppress()
+    quote = pp.Literal('"').suppress()
+    comma = pp.Literal(",").suppress()
+    dot = pp.Literal(".").suppress()
+    equals = pp.Literal("=").suppress()
+    escaped_lparen = pp.Literal('\\(')
+    escaped_rparen = pp.Literal('\\)')
+    escaped_quote = pp.Literal('\\"')
+    escaped_comma = pp.Literal('\\,')
+    escaped_dot = pp.Literal('\\.')
+    escaped_plus = pp.Literal('\\+')
+    escaped_minus = pp.Literal('\\-')
+    escaped_equals = pp.Literal('\\=')
+    syntactic_symbols = {
+        '(': escaped_lparen,
+        ')': escaped_rparen,
+        '"': escaped_quote,
+        ',': escaped_comma,
+        '.': escaped_dot,
+        '+': escaped_plus,
+        '-': escaped_minus,
+        '=': escaped_equals,
+    }
+    syntactic_chars = "".join(syntactic_symbols.keys())
+
+    # accepts int or float notation, always maps to float
+    number = pp.pyparsing_common.real | \
+             pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
+
+    # for options
+    keyword = pp.Word(pp.alphanums + '_')
+
+    # a word that absolutely does not contain any meaningful syntax
+    non_syntax_word = pp.Combine(pp.OneOrMore(pp.MatchFirst([
+        pp.Or(syntactic_symbols.values()),
+        pp.one_of(['-', '+']) + pp.NotAny(pp.White() | pp.Char(syntactic_chars) | pp.StringEnd()),
+        # build character-by-character
+        pp.CharsNotIn(string.whitespace + syntactic_chars, exact=1)
+    ])))
+    non_syntax_word.set_parse_action(lambda x: [Fragment(t) for t in x])
+    non_syntax_word.set_name('non_syntax_word')
+    non_syntax_word.set_debug(False)
+
+    # a word that can contain any character at all - greedily consumes syntax, so use with care
+    free_word = pp.CharsNotIn(string.whitespace).set_parse_action(lambda x: Fragment(x[0]))
+    free_word.set_name('free_word')
+    free_word.set_debug(False)
+
+    # ok here we go. forward declare some things..
+    attention = pp.Forward()
+    cross_attention_substitute = pp.Forward()
+    parenthesized_fragment = pp.Forward()
+    quoted_fragment = pp.Forward()
+
+    # the types of things that can go into a fragment, consisting of syntax-full and/or strictly syntax-free components
+    fragment_part_expressions = [
+        attention,
+        cross_attention_substitute,
+        parenthesized_fragment,
+        quoted_fragment,
+        non_syntax_word
+    ]
+    # a fragment that is permitted to contain commas
+    fragment_including_commas = pp.ZeroOrMore(pp.MatchFirst(
+        fragment_part_expressions + [
+            pp.Literal(',').set_parse_action(lambda x: Fragment(x[0]))
+        ]
+    ))
+    # a fragment that is not permitted to contain commas
+    fragment_excluding_commas = pp.ZeroOrMore(pp.MatchFirst(
+        fragment_part_expressions
+    ))
+
+    # a fragment in double quotes (may be nested)
     quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
-    quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment')
-
-    escaped_quote = pp.Literal('\\"')#.set_parse_action(lambda x: '"')
-    escaped_lparen = pp.Literal('\\(')#.set_parse_action(lambda x: '(')
-    escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')')
-    escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"')
-
-    empty = (
-        (lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) |
-        (quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')
-
-    def not_ends_with_swap(x):
-        #print("trying to match:", x)
-        return not x[0].endswith('.swap')
-
-    unquoted_word = (pp.Combine(pp.OneOrMore(
-        escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash |
-        (pp.CharsNotIn(string.whitespace + '\\"()', exact=1)
-    )))
-        # don't whitespace when the next word starts with +, eg "badly +formed"
-        + (pp.White().suppress() |
-           # don't eat +/-
-           pp.NotAny(pp.Word('+') | pp.Word('-'))
-           )
-    )
-    unquoted_word.set_parse_action(make_text_fragment).set_name('unquoted_word').set_debug(False)
-    #print(unquoted_fragment.parse_string("cat.swap(dog)"))
-
-    parenthesized_fragment << (lparen +
-        pp.Or([
-            (parenthesized_fragment),
-            (quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False)).set_name('-quoted_paren_internal').set_debug(False),
-            (pp.Combine(pp.OneOrMore(
-                escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash |
-                pp.CharsNotIn(string.whitespace + '\\"()', exact=1) |
-                pp.White()
-            )).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False)),
-            pp.Empty()
-        ]) + rparen)
-    parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False)
-
-    debug_attention = False
-    # attention control of the form (phrase)+ / (phrase)+ / (phrase)<weight>
-    # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
-    attention_with_parens = pp.Forward()
-    attention_without_parens = pp.Forward()
-
-    attention_with_parens_foot = (number | pp.Word('+') | pp.Word('-'))\
-        .set_name("attention_foot")\
-        .set_debug(False)
-    attention_with_parens <<= pp.Group(
-        lparen +
-        pp.ZeroOrMore(quoted_fragment | attention_with_parens | parenthesized_fragment | cross_attention_substitute | attention_without_parens |
-                      (pp.Empty() + build_escaped_word_parser_charbychar('()')).set_name('undecorated_word').set_debug(debug_attention)#.set_parse_action(lambda t: t[0])
-                      )
-        + rparen + attention_with_parens_foot)
-    attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention)
-
-    attention_without_parens_foot = (pp.NotAny(pp.White()) + pp.Or([pp.Word('+'), pp.Word('-')]) + pp.FollowedBy(pp.StringEnd() | pp.White() | pp.Literal('(') | pp.Literal(')') | pp.Literal(',') | pp.Literal('"')) ).set_name('attention_without_parens_foots')
-    attention_without_parens <<= pp.Group(pp.MatchFirst([
-        quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot,
-        pp.Combine(build_escaped_word_parser_charbychar('()+-')).set_name('attention_word_without_parens').set_debug(debug_attention)#.set_parse_action(lambda x: print('escapéd', x))
-        + attention_without_parens_foot#.leave_whitespace()
-    ]))
-    attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention)
-
-    attention << pp.MatchFirst([attention_with_parens,
-                                attention_without_parens
-                                ])
+    quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, fragment_including_commas, in_quotes=True))
+
+    # a fragment inside parentheses (may be nested)
+    parenthesized_fragment << (lparen + fragment_including_commas + rparen)
+    parenthesized_fragment.set_name('parenthesized_fragment')
+    parenthesized_fragment.set_debug(False)
+
+    # a string of the form (<keyword>=<float|keyword> | <float> | <keyword>) where keyword is alphanumeric + '_'
+    option = pp.Group(pp.MatchFirst([
+        keyword + equals + (number | keyword), # option=value
+        number.copy().set_parse_action(pp.token_map(str)), # weight
+        keyword # flag
+    ]))
+    # options for an operator, eg "s_start=0.1, 0.3, no_normalize"
+    options = pp.Dict(pp.Optional(pp.delimited_list(option)))
+    options.set_name('options')
+    options.set_debug(False)
+
+    # a fragment which can be used as the target for an operator - either quoted or in parentheses, or a bare vanilla word
+    potential_operator_target = (quoted_fragment | parenthesized_fragment | non_syntax_word)
+
+    # a fragment whose weight has been increased or decreased by a given amount
+    attention_weight_operator = pp.Word('+') | pp.Word('-') | number
+    attention_explicit = (
+        pp.Group(potential_operator_target)
+        + pp.Literal('.attend')
+        + lparen
+        + pp.Group(attention_weight_operator)
+        + rparen
+    )
+    attention_explicit.set_parse_action(make_operator_object)
+    attention_implicit = (
+        pp.Group(potential_operator_target)
+        + pp.NotAny(pp.White()) # do not permit whitespace between term and operator
+        + pp.Group(attention_weight_operator)
+    )
+    attention_implicit.set_parse_action(lambda x: make_operator_object([x[0], '.attend', x[1]]))
+    attention << (attention_explicit | attention_implicit)
     attention.set_name('attention')
-
-    def make_attention(x):
-        #print("entered make_attention with", x)
-        children = x[0][:-1]
-        weight_raw = x[0][-1]
-        weight = 1.0
-        if type(weight_raw) is float or type(weight_raw) is int:
-            weight = weight_raw
-        elif type(weight_raw) is str:
-            base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
-            weight = pow(base, len(weight_raw))
-
-        #print("making Attention from", children, "with weight", weight)
-        return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in children])
-
-    attention_with_parens.set_parse_action(make_attention)
-    attention_without_parens.set_parse_action(make_attention)
-
-    #print("parsing test:", attention_with_parens.parse_string("mountain (man)1.1"))
-
-    # cross-attention control
-    empty_string = ((lparen + rparen) |
-                    pp.Literal('""').suppress() |
-                    (lparen + pp.Literal('""').suppress() + rparen)
-                    ).set_parse_action(lambda x: Fragment(""))
-    empty_string.set_name('empty_string')
-
-    # cross attention control
-    debug_cross_attention_control = False
-    original_fragment = pp.MatchFirst([
-        quoted_fragment.set_debug(debug_cross_attention_control),
-        parenthesized_fragment.set_debug(debug_cross_attention_control),
-        pp.Combine(pp.OneOrMore(pp.CharsNotIn(string.whitespace + '.', exact=1))).set_parse_action(make_text_fragment) + pp.FollowedBy(".swap"),
-        empty_string.set_debug(debug_cross_attention_control),
-    ])
-    # support keyword=number arguments
-    cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")])
-    cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number)
-    edited_fragment = pp.MatchFirst([
-        (lparen + rparen).set_parse_action(lambda x: Fragment('')),
-        lparen +
-        (quoted_fragment | attention |
-         pp.Group(pp.ZeroOrMore(build_escaped_word_parser_charbychar(',)').set_parse_action(make_text_fragment)))
-         ) +
-        pp.Dict(pp.ZeroOrMore(comma + cross_attention_option)) +
-        rparen,
-        parenthesized_fragment
-    ])
-    cross_attention_substitute << original_fragment + pp.Literal(".swap").set_debug(False).suppress() + edited_fragment
-
-    original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
-    edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control)
-    cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control)
-
-    def make_cross_attention_substitute(x):
-        #print("making cacs for", x[0], "->", x[1], "with options", x.as_dict())
-        #if len(x>2):
-        cacs = CrossAttentionControlSubstitute(x[0], x[1], options=x.as_dict())
-        #print("made", cacs)
-        return cacs
-    cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
-
-    # root prompt definition
-    debug_root_prompt = False
-    prompt = (pp.OneOrMore(pp.MatchFirst([cross_attention_substitute.set_debug(debug_root_prompt),
-                                          attention.set_debug(debug_root_prompt),
-                                          quoted_fragment.set_debug(debug_root_prompt),
-                                          parenthesized_fragment.set_debug(debug_root_prompt),
-                                          unquoted_word.set_debug(debug_root_prompt),
-                                          empty.set_parse_action(make_text_fragment).set_debug(debug_root_prompt)])
-                           ) + pp.StringEnd()) \
-        .set_name('prompt') \
-        .set_parse_action(lambda x: Prompt(x)) \
-        .set_debug(debug_root_prompt)
-
-    #print("parsing test:", prompt.parse_string("spaced eyes--"))
-    #print("parsing test:", prompt.parse_string("eyes--"))
-
-    # weighted blend of prompts
-    # ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or
-    # int weights.
-    # can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c)
-
-    def make_prompt_from_quoted_string(x):
-        #print(' got quoted prompt', x)
-
-        x_unquoted = x[0][1:-1]
-        if len(x_unquoted.strip()) == 0:
-            # print(' b : just an empty string')
-            return Prompt([Fragment('')])
-        #print(f' b parsing \'{x_unquoted}\'')
-        x_parsed = prompt.parse_string(x_unquoted)
-        #print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed)
-        return x_parsed[0]
-
-    quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
-    quoted_prompt.set_name('quoted_prompt')
-
-    debug_blend=False
-    blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend)
-    blend_weights = (pp.delimited_list(number) + pp.Optional(pp.Char(",").suppress() + "no_normalize")).set_name('blend_weights').set_debug(debug_blend)
-    blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
-                     + pp.Literal(".blend").suppress()
-                     + lparen + pp.Group(blend_weights) + rparen).set_name('blend')
-    blend.set_debug(debug_blend)
-
-    def make_blend(x):
-        prompts = x[0][0]
-        weights = x[0][1]
-        normalize = True
-        if weights[-1] == 'no_normalize':
-            normalize = False
-            weights = weights[:-1]
-        return Blend(prompts=prompts, weights=weights, normalize_weights=normalize)
-
-    blend.set_parse_action(make_blend)
-
-    conjunction_terms = blend_terms.copy().set_name('conjunction_terms')
-    conjunction_weights = blend_weights.copy().set_name('conjunction_weights')
-    conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen
-                                                  + pp.Literal(".and").suppress()
-                                                  + lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction')
-    def make_conjunction(x):
-        parts_raw = x[0][0]
-        weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw)
-        parts = [part for part in parts_raw]
-        return Conjunction(parts, weights)
-    conjunction_with_parens_and_quotes.set_parse_action(make_conjunction)
-
-    implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction')
+    attention.set_debug(False)
+
+    # cross-attention control by swapping one fragment for another
+    cross_attention_substitute << (
+        pp.Group(potential_operator_target).set_name('ca-target').set_debug(False)
+        + pp.Literal(".swap").set_name('ca-operator').set_debug(False)
+        + lparen
+        + pp.Group(fragment_excluding_commas).set_name('ca-replacement').set_debug(False)
+        + pp.Optional(comma + options).set_name('ca-options').set_debug(False)
+        + rparen
+    )
+    cross_attention_substitute.set_name('cross_attention_substitute')
+    cross_attention_substitute.set_debug(False)
+    cross_attention_substitute.set_parse_action(make_operator_object)
+
+    # an entire self-contained prompt, which can be used in a Blend or Conjunction
+    prompt = pp.ZeroOrMore(pp.MatchFirst([
+        cross_attention_substitute,
+        attention,
+        quoted_fragment,
+        parenthesized_fragment,
+        free_word,
+        pp.White().suppress()
+    ]))
+    quoted_prompt = quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, prompt, in_quotes=True))
+
+    # a blend/lerp between the feature vectors for two or more prompts
+    blend = (
+        lparen
+        + pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('bl-target').set_debug(False)
+        + rparen
+        + pp.Literal(".blend").set_name('bl-operator').set_debug(False)
+        + lparen
+        + pp.Group(options).set_name('bl-options').set_debug(False)
+        + rparen
+    )
+    blend.set_name('blend')
+    blend.set_debug(False)
+    blend.set_parse_action(make_operator_object)
+
+    # an operator to direct stable diffusion to step multiple times, once for each target, and then add the results together with different weights
+    explicit_conjunction = (
+        lparen
+        + pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('cj-target').set_debug(False)
+        + rparen
+        + pp.one_of([".and", ".add"]).set_name('cj-operator').set_debug(False)
+        + lparen
+        + pp.Group(options).set_name('cj-options').set_debug(False)
+        + rparen
+    )
+    explicit_conjunction.set_name('explicit_conjunction')
+    explicit_conjunction.set_debug(False)
+    explicit_conjunction.set_parse_action(make_operator_object)
+
+    # by default a prompt consists of a Conjunction with a single term
+    implicit_conjunction = (blend | pp.Group(prompt)) + pp.StringEnd()
     implicit_conjunction.set_parse_action(lambda x: Conjunction(x))

-    conjunction = conjunction_with_parens_and_quotes | implicit_conjunction
+    conjunction = (explicit_conjunction | implicit_conjunction)
+    conjunction.set_debug(False)

+    # top-level is a conjunction of one or more blends or prompts
     return conjunction, prompt

 def split_weighted_subprompts(text, skip_normalize=False)->list:
     """
     Legacy blend parsing.
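
A minimal usage sketch (not part of the diff, and assuming only the names
visible above: PromptParser, its conjunction pyparsing expression, and
flatten()):

    parser = PromptParser()
    # parse to a Conjunction of Prompts/Blends, then flatten it for the conditioning code
    tree = parser.conjunction.parse_string("a cat.swap(dog) eating a (hotdog)1.2")[0]
    flat = parser.flatten(tree)  # Conjunction of FlattenedPrompt objects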


@@ -28,8 +28,8 @@ class PromptParserTestCase(unittest.TestCase):
         self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt(''))

     def test_basic(self):
-        self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)"))
         self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames"))
+        self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)"))
         self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames"))
         self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire"))
         self.assertEqual(make_weighted_conjunction([("cat hot-dog eating", 1)]), parse_prompt("cat hot-dog eating"))
@@ -37,14 +37,25 @@ class PromptParserTestCase(unittest.TestCase):

     def test_attention(self):
         self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5"))
+        self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames).attend(0.5)"))
+        self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("flames.attend(0.5)"))
+        self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("\"flames\".attend(0.5)"))
         self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5"))
+        self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames).attend(0.5)"))
         self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+"))
         self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+"))
         self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+"))
+        self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames.attend(+)"))
+        self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames).attend(+)"))
+        self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\".attend(+)"))
         self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-"))
         self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-"))
         self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("\"flames\"-"))
         self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5"))
+        self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire flames.attend(0.5)"))
+        self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames).attend(0.5)"))
+        self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire \"flames\".attend(0.5)"))
         self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++"))
         self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--"))
         self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++"))
@@ -102,20 +113,17 @@ class PromptParserTestCase(unittest.TestCase):
         assert_if_prompt_string_not_untouched('a test prompt')
         assert_if_prompt_string_not_untouched('a badly formed +test prompt')
-        with self.assertRaises(pyparsing.ParseException):
-            parse_prompt('a badly (formed test prompt')
+        assert_if_prompt_string_not_untouched('a badly (formed test prompt')
         #with self.assertRaises(pyparsing.ParseException):
-        with self.assertRaises(pyparsing.ParseException):
-            parse_prompt('a badly (formed +test prompt')
+        assert_if_prompt_string_not_untouched('a badly (formed +test prompt')
         self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a badly formed +test prompt',1)])]) , parse_prompt('a badly (formed +test )prompt'))
-        with self.assertRaises(pyparsing.ParseException):
-            parse_prompt('(((a badly (formed +test )prompt')
-        with self.assertRaises(pyparsing.ParseException):
-            parse_prompt('(a (ba)dly (f)ormed +test prompt')
-        with self.assertRaises(pyparsing.ParseException):
-            parse_prompt('(a (ba)dly (f)ormed +test +prompt')
-        with self.assertRaises(pyparsing.ParseException):
-            parse_prompt('("((a badly (formed +test ").blend(1.0)')
+        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(((a badly formed +test prompt',1)])]) , parse_prompt('(((a badly (formed +test )prompt'))
+        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(a ba dly f ormed +test prompt',1)])]) , parse_prompt('(a (ba)dly (f)ormed +test prompt'))
+        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(a ba dly f ormed +test +prompt',1)])]) , parse_prompt('(a (ba)dly (f)ormed +test +prompt'))
+        self.assertEqual(Conjunction([Blend([FlattenedPrompt([Fragment('((a badly (formed +test', 1)])], [1.0])]),
+                         parse_prompt('("((a badly (formed +test ").blend(1.0)'))

         self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
                          parse_prompt("hamburger ((bun))"))
@@ -128,6 +136,26 @@ class PromptParserTestCase(unittest.TestCase):

     def test_blend(self):
+        self.assertEqual(Conjunction(
+            [Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
+            parse_prompt("(\"mountain\", \"man\").blend()")
+        )
+        self.assertEqual(Conjunction(
+            [Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
+            parse_prompt("(mountain, man).blend()")
+        )
+        self.assertEqual(Conjunction(
+            [Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
+            parse_prompt("((mountain), (man)).blend()")
+        )
+        self.assertEqual(Conjunction(
+            [Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('tall man', 1.0)])], [1.0, 1.0])]),
+            parse_prompt("((mountain), (tall man)).blend()")
+        )
+        with self.assertRaises(PromptParser.ParsingException):
+            print(parse_prompt("((mountain), \"cat.swap(dog)\").blend()"))
+
         self.assertEqual(Conjunction(
             [Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]),
             parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)")
@@ -166,10 +194,20 @@ class PromptParserTestCase(unittest.TestCase):
         )
         self.assertEqual(
-            Conjunction([Blend([FlattenedPrompt([('mountain, man, hairy', 1)]),
-                                FlattenedPrompt([('face, teeth,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0])]),
+            Conjunction([Blend([FlattenedPrompt([('mountain , man , hairy', 1)]),
+                                FlattenedPrompt([('face , teeth ,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0], normalize_weights=True)]),
             parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)')
         )
+        self.assertEqual(
+            Conjunction([Blend([FlattenedPrompt([('mountain , man , hairy', 1)]),
+                                FlattenedPrompt([('face , teeth ,', 1), ('eyes', 0.9 * 0.9)])], weights=[1.0, -1.0], normalize_weights=False)]),
+            parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1,no_normalize)')
+        )
+        with self.assertRaises(PromptParser.ParsingException):
+            parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3, 0.1)")
+        with self.assertRaises(PromptParser.ParsingException):
+            parse_prompt("(\"fire\", \"fire flames\").blend(0.7)")

     def test_nested(self):
@@ -182,6 +220,9 @@ class PromptParserTestCase(unittest.TestCase):

     def test_cross_attention_control(self):
+        self.assertEqual(Conjunction([FlattenedPrompt([CrossAttentionControlSubstitute([Fragment('sun')], [Fragment('moon')])])]),
+                         parse_prompt("sun.swap(moon)"))
+
         self.assertEqual(Conjunction([
             FlattenedPrompt([Fragment('a', 1),
                              CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
@@ -259,6 +300,12 @@ class PromptParserTestCase(unittest.TestCase):
                                                        Fragment(',', 1), Fragment('fire', 2.0)])])
         self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0'))

+        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
+                                                       CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
+                                                       Fragment('eating a', 1),
+                                                       CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))])
+                                                       ])]),
+                         parse_prompt("a cat.swap(dog) eating a hotdog.swap(hotdog++++)"))
         self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
                                                        CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
                                                        Fragment('eating a', 1),
@@ -343,31 +390,31 @@ class PromptParserTestCase(unittest.TestCase):
         self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy (mountain (\(man\))+)+'))
         self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy ((\(man\))1.1 "mountain")+'))
         self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy ("mountain" (\(man\))1.1 )+'))
-        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man', 1.1)]),parse_prompt('hairy ("mountain, man")+'))
-        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy ("mountain, man" with a beard+)+'))
-        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, man" with a (beard)2.0)+'))
-        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\"man\\"" with a (beard)2.0)+'))
-        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, m\\"an\\"" with a (beard)2.0)+'))
-        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \(with a (beard)2.0)+'))
-        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\(ith a (beard)2.0)+'))
-        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\( a (beard)2.0)+'))
-        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \)with a (beard)2.0)+'))
-        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
-        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\) a (beard)2.0)+'))
-        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+'))
-        self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
-        self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain, \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( ("mountain, \\\"man\" with a (beard)2.0)+'))
-        self.assertEqual(make_weighted_conjunction([('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \(with a (beard)2.0)+ hairy'))
-        self.assertEqual(make_weighted_conjunction([('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\(ith a (beard)2.0)+hairy'))
-        self.assertEqual(make_weighted_conjunction([('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" with\( a (beard)2.0)+ hairy'))
-        self.assertEqual(make_weighted_conjunction([('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \)with a (beard)2.0)+ hairy'))
-        self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hairy'))
-        self.assertEqual(make_weighted_conjunction([('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' ("mountain, \\\"man\" with\) a (beard)2.0)+ hairy'))
-        self.assertEqual(make_weighted_conjunction([('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+ hairy'))
-        self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hai\(ry '))
-        self.assertEqual(make_weighted_conjunction([('mountain, \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('("mountain, \\\"man\" with a (beard)2.0)+ hairy\(\( '))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , man', 1.1)]),parse_prompt('hairy ("mountain, man")+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy ("mountain, man" with a beard+)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, man" with a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\"man\\"" with a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, m\\"an\\"" with a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \(with a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\(ith a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\( a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \)with a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\) a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain , \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain , \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( ("mountain, \\\"man\" with a (beard)2.0)+'))
+        self.assertEqual(make_weighted_conjunction([('mountain , \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \(with a (beard)2.0)+ hairy'))
+        self.assertEqual(make_weighted_conjunction([('mountain , \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\(ith a (beard)2.0)+hairy'))
+        self.assertEqual(make_weighted_conjunction([('mountain , \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" with\( a (beard)2.0)+ hairy'))
+        self.assertEqual(make_weighted_conjunction([('mountain , \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \)with a (beard)2.0)+ hairy'))
+        self.assertEqual(make_weighted_conjunction([('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hairy'))
+        self.assertEqual(make_weighted_conjunction([('mountain , \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' ("mountain, \\\"man\" with\) a (beard)2.0)+ hairy'))
+        self.assertEqual(make_weighted_conjunction([('mou)ntain , \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+ hairy'))
+        self.assertEqual(make_weighted_conjunction([('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hai\(ry '))
+        self.assertEqual(make_weighted_conjunction([('mountain , \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('("mountain, \\\"man\" with a (beard)2.0)+ hairy\(\( '))

     def test_cross_attention_escaping(self):
@@ -433,6 +480,15 @@ class PromptParserTestCase(unittest.TestCase):

     def test_single(self):
+        self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]),
+                                      FlattenedPrompt([("a person with a hat", 1.0),
+                                                       ("riding a", 1.1*1.1),
+                                                       CrossAttentionControlSubstitute(
+                                                           [Fragment("bicycle", pow(1.1,2))],
+                                                           [Fragment("skateboard", pow(1.1,2))])
+                                                       ])
+                                      ], weights=[0.5, 0.5]),
+                         parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)"))
         pass