mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Rebuilt prompt parsing logic
Complete re-write of the prompt parsing logic to be more readable and logical, and therefore also hopefully easier to debug, maintain, and augment. In the process it has also become more robust to badly-formed prompts. Squashed commit of the following: commit 8fcfa88a16e1390d41717e940d72aed64712171c Author: Damian at mba <damian@frey.NOSPAMco.nz> Date: Sun Oct 30 17:05:57 2022 +0100 further cleanup commit 1a1fd78bcfeb49d072e3e6d5808aa8df15441629 Author: Damian at mba <damian@frey.NOSPAMco.nz> Date: Sun Oct 30 16:07:57 2022 +0100 cleanup and document commit 099c9659fa8b8135876f9a5a50fe80b20bc0635c Author: Damian at mba <damian@frey.NOSPAMco.nz> Date: Sun Oct 30 15:54:58 2022 +0100 works fully commit 5e6887ea8c25a1e21438ff6defb381fd027d25fd Author: Damian at mba <damian@frey.NOSPAMco.nz> Date: Sun Oct 30 15:24:31 2022 +0100 further... commit 492fda120844d9bc1ad4ec7dd408a3374762d0ff Author: Damian at mba <damian@frey.NOSPAMco.nz> Date: Sun Oct 30 14:08:57 2022 +0100 getting there... commit c6aab05a8450cc3c95c8691daf38fdc64c74f52d Author: Damian at mba <damian@frey.NOSPAMco.nz> Date: Fri Oct 28 14:29:03 2022 +0200 wip doesn't compile commit 5e533f731cfd20cd435330eeb0012e5689e87e81 Author: Damian at mba <damian@frey.NOSPAMco.nz> Date: Fri Oct 28 13:21:43 2022 +0200 working with CrossAttentionCtonrol but no Attention support yet commit 9678348773431e500e110e8aede99086bb7b5955 Author: Damian at mba <damian@frey.NOSPAMco.nz> Date: Fri Oct 28 13:04:52 2022 +0200 wip rebuiling prompt parser
This commit is contained in:
parent
349cc25433
commit
e554c2607f
@ -28,7 +28,7 @@ class Prompt():
|
||||
def __init__(self, parts: list):
|
||||
for c in parts:
|
||||
if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
|
||||
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} {c}, only {BaseFragment.__subclasses__()} are allowed")
|
||||
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} ({c}), only {[c.__name__ for c in BaseFragment.__subclasses__()]} are allowed")
|
||||
self.children = parts
|
||||
def __repr__(self):
|
||||
return f"Prompt:{self.children}"
|
||||
@ -102,12 +102,18 @@ class Attention():
|
||||
Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object.
|
||||
"""
|
||||
def __init__(self, weight: float, children: list):
|
||||
if type(weight) is not float:
|
||||
raise PromptParser.ParsingException(
|
||||
f"Attention weight must be float (got {type(weight).__name__} {weight})")
|
||||
self.weight = weight
|
||||
if type(children) is not list:
|
||||
raise PromptParser.ParsingException(f"cannot make Attention with non-list of children (got {type(children)})")
|
||||
assert(type(children) is list)
|
||||
self.children = children
|
||||
#print(f"A: requested attention '{children}' to {weight}")
|
||||
|
||||
def __repr__(self):
|
||||
return f"Attention:'{self.children}' @ {self.weight}"
|
||||
return f"Attention:{self.children} * {self.weight}"
|
||||
def __eq__(self, other):
|
||||
return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment
|
||||
|
||||
@ -136,9 +142,9 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
|
||||
Fragment('sitting on a car')
|
||||
])
|
||||
"""
|
||||
def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list], options: dict=None):
|
||||
def __init__(self, original: list, edited: list, options: dict=None):
|
||||
self.original = original
|
||||
self.edited = edited
|
||||
self.edited = edited if len(edited)>0 else [Fragment('')]
|
||||
|
||||
default_options = {
|
||||
's_start': 0.0,
|
||||
@ -190,12 +196,12 @@ class Conjunction():
|
||||
"""
|
||||
def __init__(self, prompts: list, weights: list = None):
|
||||
# force everything to be a Prompt
|
||||
#print("making conjunction with", parts)
|
||||
#print("making conjunction with", prompts, "types", [type(p).__name__ for p in prompts])
|
||||
self.prompts = [x if (type(x) is Prompt
|
||||
or type(x) is Blend
|
||||
or type(x) is FlattenedPrompt)
|
||||
else Prompt(x) for x in prompts]
|
||||
self.weights = [1.0]*len(self.prompts) if weights is None else list(weights)
|
||||
self.weights = [1.0]*len(self.prompts) if (weights is None or len(weights)==0) else list(weights)
|
||||
if len(self.weights) != len(self.prompts):
|
||||
raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
|
||||
self.type = 'AND'
|
||||
@ -216,6 +222,7 @@ class Blend():
|
||||
"""
|
||||
def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
|
||||
#print("making Blend with prompts", prompts, "and weights", weights)
|
||||
weights = [1.0]*len(prompts) if (weights is None or len(weights)==0) else list(weights)
|
||||
if len(prompts) != len(weights):
|
||||
raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
|
||||
for p in prompts:
|
||||
@ -244,6 +251,10 @@ class PromptParser():
|
||||
class ParsingException(Exception):
|
||||
pass
|
||||
|
||||
class UnrecognizedOperatorException(Exception):
|
||||
def __init__(self, operator:str):
|
||||
super().__init__("Unrecognized operator: " + operator)
|
||||
|
||||
def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
|
||||
|
||||
self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base)
|
||||
@ -279,7 +290,7 @@ class PromptParser():
|
||||
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
|
||||
|
||||
|
||||
def flatten(self, root: Conjunction) -> Conjunction:
|
||||
def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
|
||||
"""
|
||||
Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
|
||||
producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
|
||||
@ -289,8 +300,6 @@ class PromptParser():
|
||||
:return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
|
||||
"""
|
||||
|
||||
#print("flattening", root)
|
||||
|
||||
def fuse_fragments(items):
|
||||
# print("fusing fragments in ", items)
|
||||
result = []
|
||||
@ -313,8 +322,8 @@ class PromptParser():
|
||||
return result
|
||||
|
||||
def flatten_internal(node, weight_scale, results, prefix):
|
||||
#print(prefix + "flattening", node, "...")
|
||||
if type(node) is pp.ParseResults:
|
||||
verbose and print(prefix + "flattening", node, "...")
|
||||
if type(node) is pp.ParseResults or type(node) is list:
|
||||
for x in node:
|
||||
results = flatten_internal(x, weight_scale, results, prefix+' pr ')
|
||||
#print(prefix, " ParseResults expanded, results is now", results)
|
||||
@ -345,67 +354,59 @@ class PromptParser():
|
||||
#print(prefix + "after flattening Prompt, results is", results)
|
||||
else:
|
||||
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
|
||||
#print(prefix + "-> after flattening", type(node).__name__, "results is", results)
|
||||
verbose and print(prefix + "-> after flattening", type(node).__name__, "results is", results)
|
||||
return results
|
||||
|
||||
verbose and print("flattening", root)
|
||||
|
||||
flattened_parts = []
|
||||
for part in root.prompts:
|
||||
flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
|
||||
|
||||
#print("flattened to", flattened_parts)
|
||||
verbose and print("flattened to", flattened_parts)
|
||||
|
||||
weights = root.weights
|
||||
return Conjunction(flattened_parts, weights)
|
||||
|
||||
|
||||
|
||||
|
||||
def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
|
||||
def make_operator_object(x):
|
||||
#print('making operator for', x)
|
||||
target = x[0]
|
||||
operator = x[1]
|
||||
arguments = x[2]
|
||||
if operator == '.attend':
|
||||
weight_raw = arguments[0]
|
||||
weight = 1.0
|
||||
if type(weight_raw) is float or type(weight_raw) is int:
|
||||
weight = weight_raw
|
||||
elif type(weight_raw) is str:
|
||||
base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
|
||||
weight = pow(base, len(weight_raw))
|
||||
return Attention(weight=weight, children=[x for x in x[0]])
|
||||
elif operator == '.swap':
|
||||
return CrossAttentionControlSubstitute(target, arguments, x.as_dict())
|
||||
elif operator == '.blend':
|
||||
prompts = [Prompt(p) for p in x[0]]
|
||||
weights_raw = x[2]
|
||||
normalize_weights = True
|
||||
if len(weights_raw) > 0 and weights_raw[-1][0] == 'no_normalize':
|
||||
normalize_weights = False
|
||||
weights_raw = weights_raw[:-1]
|
||||
weights = [float(w[0]) for w in weights_raw]
|
||||
return Blend(prompts=prompts, weights=weights, normalize_weights=normalize_weights)
|
||||
elif operator == '.and' or operator == '.add':
|
||||
prompts = [Prompt(p) for p in x[0]]
|
||||
weights = [float(w[0]) for w in x[2]]
|
||||
return Conjunction(prompts=prompts, weights=weights)
|
||||
|
||||
lparen = pp.Literal("(").suppress()
|
||||
rparen = pp.Literal(")").suppress()
|
||||
quotes = pp.Literal('"').suppress()
|
||||
comma = pp.Literal(",").suppress()
|
||||
raise PromptParser.UnrecognizedOperatorException(operator)
|
||||
|
||||
# accepts int or float notation, always maps to float
|
||||
number = pp.pyparsing_common.real | \
|
||||
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
|
||||
|
||||
attention = pp.Forward()
|
||||
quoted_fragment = pp.Forward()
|
||||
parenthesized_fragment = pp.Forward()
|
||||
cross_attention_substitute = pp.Forward()
|
||||
|
||||
def make_text_fragment(x):
|
||||
#print("### making fragment for", x)
|
||||
if type(x[0]) is Fragment:
|
||||
assert(False)
|
||||
if type(x) is str:
|
||||
return Fragment(x)
|
||||
elif type(x) is pp.ParseResults or type(x) is list:
|
||||
#print(f'converting {type(x).__name__} to Fragment')
|
||||
return Fragment(' '.join([s for s in x]))
|
||||
else:
|
||||
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
|
||||
|
||||
def build_escaped_word_parser_charbychar(escaped_chars_to_ignore: str):
|
||||
escapes = []
|
||||
for c in escaped_chars_to_ignore:
|
||||
escapes.append(pp.Literal('\\'+c))
|
||||
return pp.Combine(pp.OneOrMore(
|
||||
pp.MatchFirst(escapes + [pp.CharsNotIn(
|
||||
string.whitespace + escaped_chars_to_ignore,
|
||||
exact=1
|
||||
)])
|
||||
))
|
||||
|
||||
|
||||
|
||||
def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
|
||||
def parse_fragment_str(x, expression: pp.ParseExpression, in_quotes: bool = False, in_parens: bool = False):
|
||||
#print(f"parsing fragment string for {x}")
|
||||
fragment_string = x[0]
|
||||
#print(f"ppparsing fragment string \"{fragment_string}\"")
|
||||
|
||||
if len(fragment_string.strip()) == 0:
|
||||
return Fragment('')
|
||||
|
||||
@ -413,234 +414,198 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
||||
# escape unescaped quotes
|
||||
fragment_string = fragment_string.replace('"', '\\"')
|
||||
|
||||
#fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment))))
|
||||
try:
|
||||
result = pp.Group(pp.MatchFirst([
|
||||
pp.OneOrMore(quoted_fragment | attention | unquoted_word).set_name('pf_str_qfuq'),
|
||||
pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd()
|
||||
])).set_name('blend-result').set_debug(False).parse_string(fragment_string)
|
||||
result = (expression + pp.StringEnd()).parse_string(fragment_string)
|
||||
#print("parsed to", result)
|
||||
return result
|
||||
except pp.ParseException as e:
|
||||
#print("parse_fragment_str couldn't parse prompt string:", e)
|
||||
raise
|
||||
|
||||
# meaningful symbols
|
||||
lparen = pp.Literal("(").suppress()
|
||||
rparen = pp.Literal(")").suppress()
|
||||
quote = pp.Literal('"').suppress()
|
||||
comma = pp.Literal(",").suppress()
|
||||
dot = pp.Literal(".").suppress()
|
||||
equals = pp.Literal("=").suppress()
|
||||
|
||||
escaped_lparen = pp.Literal('\\(')
|
||||
escaped_rparen = pp.Literal('\\)')
|
||||
escaped_quote = pp.Literal('\\"')
|
||||
escaped_comma = pp.Literal('\\,')
|
||||
escaped_dot = pp.Literal('\\.')
|
||||
escaped_plus = pp.Literal('\\+')
|
||||
escaped_minus = pp.Literal('\\-')
|
||||
escaped_equals = pp.Literal('\\=')
|
||||
|
||||
syntactic_symbols = {
|
||||
'(': escaped_lparen,
|
||||
')': escaped_rparen,
|
||||
'"': escaped_quote,
|
||||
',': escaped_comma,
|
||||
'.': escaped_dot,
|
||||
'+': escaped_plus,
|
||||
'-': escaped_minus,
|
||||
'=': escaped_equals,
|
||||
}
|
||||
syntactic_chars = "".join(syntactic_symbols.keys())
|
||||
|
||||
# accepts int or float notation, always maps to float
|
||||
number = pp.pyparsing_common.real | \
|
||||
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
|
||||
|
||||
# for options
|
||||
keyword = pp.Word(pp.alphanums + '_')
|
||||
|
||||
# a word that absolutely does not contain any meaningful syntax
|
||||
non_syntax_word = pp.Combine(pp.OneOrMore(pp.MatchFirst([
|
||||
pp.Or(syntactic_symbols.values()),
|
||||
pp.one_of(['-', '+']) + pp.NotAny(pp.White() | pp.Char(syntactic_chars) | pp.StringEnd()),
|
||||
# build character-by-character
|
||||
pp.CharsNotIn(string.whitespace + syntactic_chars, exact=1)
|
||||
])))
|
||||
non_syntax_word.set_parse_action(lambda x: [Fragment(t) for t in x])
|
||||
non_syntax_word.set_name('non_syntax_word')
|
||||
non_syntax_word.set_debug(False)
|
||||
|
||||
# a word that can contain any character at all - greedily consumes syntax, so use with care
|
||||
free_word = pp.CharsNotIn(string.whitespace).set_parse_action(lambda x: Fragment(x[0]))
|
||||
free_word.set_name('free_word')
|
||||
free_word.set_debug(False)
|
||||
|
||||
|
||||
# ok here we go. forward declare some things..
|
||||
attention = pp.Forward()
|
||||
cross_attention_substitute = pp.Forward()
|
||||
parenthesized_fragment = pp.Forward()
|
||||
quoted_fragment = pp.Forward()
|
||||
|
||||
# the types of things that can go into a fragment, consisting of syntax-full and/or strictly syntax-free components
|
||||
fragment_part_expressions = [
|
||||
attention,
|
||||
cross_attention_substitute,
|
||||
parenthesized_fragment,
|
||||
quoted_fragment,
|
||||
non_syntax_word
|
||||
]
|
||||
# a fragment that is permitted to contain commas
|
||||
fragment_including_commas = pp.ZeroOrMore(pp.MatchFirst(
|
||||
fragment_part_expressions + [
|
||||
pp.Literal(',').set_parse_action(lambda x: Fragment(x[0]))
|
||||
]
|
||||
))
|
||||
# a fragment that is not permitted to contain commas
|
||||
fragment_excluding_commas = pp.ZeroOrMore(pp.MatchFirst(
|
||||
fragment_part_expressions
|
||||
))
|
||||
|
||||
# a fragment in double quotes (may be nested)
|
||||
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
|
||||
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment')
|
||||
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, fragment_including_commas, in_quotes=True))
|
||||
|
||||
escaped_quote = pp.Literal('\\"')#.set_parse_action(lambda x: '"')
|
||||
escaped_lparen = pp.Literal('\\(')#.set_parse_action(lambda x: '(')
|
||||
escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')')
|
||||
escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"')
|
||||
# a fragment inside parentheses (may be nested)
|
||||
parenthesized_fragment << (lparen + fragment_including_commas + rparen)
|
||||
parenthesized_fragment.set_name('parenthesized_fragment')
|
||||
parenthesized_fragment.set_debug(False)
|
||||
|
||||
empty = (
|
||||
(lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) |
|
||||
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')
|
||||
|
||||
|
||||
def not_ends_with_swap(x):
|
||||
#print("trying to match:", x)
|
||||
return not x[0].endswith('.swap')
|
||||
|
||||
unquoted_word = (pp.Combine(pp.OneOrMore(
|
||||
escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash |
|
||||
(pp.CharsNotIn(string.whitespace + '\\"()', exact=1)
|
||||
)))
|
||||
# don't whitespace when the next word starts with +, eg "badly +formed"
|
||||
+ (pp.White().suppress() |
|
||||
# don't eat +/-
|
||||
pp.NotAny(pp.Word('+') | pp.Word('-'))
|
||||
)
|
||||
)
|
||||
|
||||
unquoted_word.set_parse_action(make_text_fragment).set_name('unquoted_word').set_debug(False)
|
||||
#print(unquoted_fragment.parse_string("cat.swap(dog)"))
|
||||
|
||||
parenthesized_fragment << (lparen +
|
||||
pp.Or([
|
||||
(parenthesized_fragment),
|
||||
(quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False)).set_name('-quoted_paren_internal').set_debug(False),
|
||||
(pp.Combine(pp.OneOrMore(
|
||||
escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash |
|
||||
pp.CharsNotIn(string.whitespace + '\\"()', exact=1) |
|
||||
pp.White()
|
||||
)).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False)),
|
||||
pp.Empty()
|
||||
]) + rparen)
|
||||
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False)
|
||||
|
||||
debug_attention = False
|
||||
# attention control of the form (phrase)+ / (phrase)+ / (phrase)<weight>
|
||||
# phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
|
||||
attention_with_parens = pp.Forward()
|
||||
attention_without_parens = pp.Forward()
|
||||
|
||||
attention_with_parens_foot = (number | pp.Word('+') | pp.Word('-'))\
|
||||
.set_name("attention_foot")\
|
||||
.set_debug(False)
|
||||
attention_with_parens <<= pp.Group(
|
||||
lparen +
|
||||
pp.ZeroOrMore(quoted_fragment | attention_with_parens | parenthesized_fragment | cross_attention_substitute | attention_without_parens |
|
||||
(pp.Empty() + build_escaped_word_parser_charbychar('()')).set_name('undecorated_word').set_debug(debug_attention)#.set_parse_action(lambda t: t[0])
|
||||
)
|
||||
+ rparen + attention_with_parens_foot)
|
||||
attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention)
|
||||
|
||||
attention_without_parens_foot = (pp.NotAny(pp.White()) + pp.Or([pp.Word('+'), pp.Word('-')]) + pp.FollowedBy(pp.StringEnd() | pp.White() | pp.Literal('(') | pp.Literal(')') | pp.Literal(',') | pp.Literal('"')) ).set_name('attention_without_parens_foots')
|
||||
attention_without_parens <<= pp.Group(pp.MatchFirst([
|
||||
quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot,
|
||||
pp.Combine(build_escaped_word_parser_charbychar('()+-')).set_name('attention_word_without_parens').set_debug(debug_attention)#.set_parse_action(lambda x: print('escapéd', x))
|
||||
+ attention_without_parens_foot#.leave_whitespace()
|
||||
# a string of the form (<keyword>=<float|keyword> | <float> | <keyword>) where keyword is alphanumeric + '_'
|
||||
option = pp.Group(pp.MatchFirst([
|
||||
keyword + equals + (number | keyword), # option=value
|
||||
number.copy().set_parse_action(pp.token_map(str)), # weight
|
||||
keyword # flag
|
||||
]))
|
||||
attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention)
|
||||
# options for an operator, eg "s_start=0.1, 0.3, no_normalize"
|
||||
options = pp.Dict(pp.Optional(pp.delimited_list(option)))
|
||||
options.set_name('options')
|
||||
options.set_debug(False)
|
||||
|
||||
# a fragment which can be used as the target for an operator - either quoted or in parentheses, or a bare vanilla word
|
||||
potential_operator_target = (quoted_fragment | parenthesized_fragment | non_syntax_word)
|
||||
|
||||
attention << pp.MatchFirst([attention_with_parens,
|
||||
attention_without_parens
|
||||
])
|
||||
# a fragment whose weight has been increased or decreased by a given amount
|
||||
attention_weight_operator = pp.Word('+') | pp.Word('-') | number
|
||||
attention_explicit = (
|
||||
pp.Group(potential_operator_target)
|
||||
+ pp.Literal('.attend')
|
||||
+ lparen
|
||||
+ pp.Group(attention_weight_operator)
|
||||
+ rparen
|
||||
)
|
||||
attention_explicit.set_parse_action(make_operator_object)
|
||||
attention_implicit = (
|
||||
pp.Group(potential_operator_target)
|
||||
+ pp.NotAny(pp.White()) # do not permit whitespace between term and operator
|
||||
+ pp.Group(attention_weight_operator)
|
||||
)
|
||||
attention_implicit.set_parse_action(lambda x: make_operator_object([x[0], '.attend', x[1]]))
|
||||
attention << (attention_explicit | attention_implicit)
|
||||
attention.set_name('attention')
|
||||
attention.set_debug(False)
|
||||
|
||||
def make_attention(x):
|
||||
#print("entered make_attention with", x)
|
||||
children = x[0][:-1]
|
||||
weight_raw = x[0][-1]
|
||||
weight = 1.0
|
||||
if type(weight_raw) is float or type(weight_raw) is int:
|
||||
weight = weight_raw
|
||||
elif type(weight_raw) is str:
|
||||
base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
|
||||
weight = pow(base, len(weight_raw))
|
||||
|
||||
#print("making Attention from", children, "with weight", weight)
|
||||
|
||||
return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in children])
|
||||
|
||||
attention_with_parens.set_parse_action(make_attention)
|
||||
attention_without_parens.set_parse_action(make_attention)
|
||||
|
||||
#print("parsing test:", attention_with_parens.parse_string("mountain (man)1.1"))
|
||||
|
||||
# cross-attention control
|
||||
empty_string = ((lparen + rparen) |
|
||||
pp.Literal('""').suppress() |
|
||||
(lparen + pp.Literal('""').suppress() + rparen)
|
||||
).set_parse_action(lambda x: Fragment(""))
|
||||
empty_string.set_name('empty_string')
|
||||
|
||||
# cross attention control
|
||||
debug_cross_attention_control = False
|
||||
original_fragment = pp.MatchFirst([
|
||||
quoted_fragment.set_debug(debug_cross_attention_control),
|
||||
parenthesized_fragment.set_debug(debug_cross_attention_control),
|
||||
pp.Combine(pp.OneOrMore(pp.CharsNotIn(string.whitespace + '.', exact=1))).set_parse_action(make_text_fragment) + pp.FollowedBy(".swap"),
|
||||
empty_string.set_debug(debug_cross_attention_control),
|
||||
])
|
||||
# support keyword=number arguments
|
||||
cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")])
|
||||
cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number)
|
||||
edited_fragment = pp.MatchFirst([
|
||||
(lparen + rparen).set_parse_action(lambda x: Fragment('')),
|
||||
lparen +
|
||||
(quoted_fragment | attention |
|
||||
pp.Group(pp.ZeroOrMore(build_escaped_word_parser_charbychar(',)').set_parse_action(make_text_fragment)))
|
||||
) +
|
||||
pp.Dict(pp.ZeroOrMore(comma + cross_attention_option)) +
|
||||
rparen,
|
||||
parenthesized_fragment
|
||||
])
|
||||
cross_attention_substitute << original_fragment + pp.Literal(".swap").set_debug(False).suppress() + edited_fragment
|
||||
|
||||
original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
|
||||
edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control)
|
||||
cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control)
|
||||
|
||||
def make_cross_attention_substitute(x):
|
||||
#print("making cacs for", x[0], "->", x[1], "with options", x.as_dict())
|
||||
#if len(x>2):
|
||||
cacs = CrossAttentionControlSubstitute(x[0], x[1], options=x.as_dict())
|
||||
#print("made", cacs)
|
||||
return cacs
|
||||
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
|
||||
# cross-attention control by swapping one fragment for another
|
||||
cross_attention_substitute << (
|
||||
pp.Group(potential_operator_target).set_name('ca-target').set_debug(False)
|
||||
+ pp.Literal(".swap").set_name('ca-operator').set_debug(False)
|
||||
+ lparen
|
||||
+ pp.Group(fragment_excluding_commas).set_name('ca-replacement').set_debug(False)
|
||||
+ pp.Optional(comma + options).set_name('ca-options').set_debug(False)
|
||||
+ rparen
|
||||
)
|
||||
cross_attention_substitute.set_name('cross_attention_substitute')
|
||||
cross_attention_substitute.set_debug(False)
|
||||
cross_attention_substitute.set_parse_action(make_operator_object)
|
||||
|
||||
|
||||
# root prompt definition
|
||||
debug_root_prompt = False
|
||||
prompt = (pp.OneOrMore(pp.MatchFirst([cross_attention_substitute.set_debug(debug_root_prompt),
|
||||
attention.set_debug(debug_root_prompt),
|
||||
quoted_fragment.set_debug(debug_root_prompt),
|
||||
parenthesized_fragment.set_debug(debug_root_prompt),
|
||||
unquoted_word.set_debug(debug_root_prompt),
|
||||
empty.set_parse_action(make_text_fragment).set_debug(debug_root_prompt)])
|
||||
) + pp.StringEnd()) \
|
||||
.set_name('prompt') \
|
||||
.set_parse_action(lambda x: Prompt(x)) \
|
||||
.set_debug(debug_root_prompt)
|
||||
# an entire self-contained prompt, which can be used in a Blend or Conjunction
|
||||
prompt = pp.ZeroOrMore(pp.MatchFirst([
|
||||
cross_attention_substitute,
|
||||
attention,
|
||||
quoted_fragment,
|
||||
parenthesized_fragment,
|
||||
free_word,
|
||||
pp.White().suppress()
|
||||
]))
|
||||
quoted_prompt = quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, prompt, in_quotes=True))
|
||||
|
||||
#print("parsing test:", prompt.parse_string("spaced eyes--"))
|
||||
#print("parsing test:", prompt.parse_string("eyes--"))
|
||||
|
||||
# weighted blend of prompts
|
||||
# ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or
|
||||
# int weights.
|
||||
# can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c)
|
||||
# a blend/lerp between the feature vectors for two or more prompts
|
||||
blend = (
|
||||
lparen
|
||||
+ pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('bl-target').set_debug(False)
|
||||
+ rparen
|
||||
+ pp.Literal(".blend").set_name('bl-operator').set_debug(False)
|
||||
+ lparen
|
||||
+ pp.Group(options).set_name('bl-options').set_debug(False)
|
||||
+ rparen
|
||||
)
|
||||
blend.set_name('blend')
|
||||
blend.set_debug(False)
|
||||
blend.set_parse_action(make_operator_object)
|
||||
|
||||
def make_prompt_from_quoted_string(x):
|
||||
#print(' got quoted prompt', x)
|
||||
# an operator to direct stable diffusion to step multiple times, once for each target, and then add the results together with different weights
|
||||
explicit_conjunction = (
|
||||
lparen
|
||||
+ pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('cj-target').set_debug(False)
|
||||
+ rparen
|
||||
+ pp.one_of([".and", ".add"]).set_name('cj-operator').set_debug(False)
|
||||
+ lparen
|
||||
+ pp.Group(options).set_name('cj-options').set_debug(False)
|
||||
+ rparen
|
||||
)
|
||||
explicit_conjunction.set_name('explicit_conjunction')
|
||||
explicit_conjunction.set_debug(False)
|
||||
explicit_conjunction.set_parse_action(make_operator_object)
|
||||
|
||||
x_unquoted = x[0][1:-1]
|
||||
if len(x_unquoted.strip()) == 0:
|
||||
# print(' b : just an empty string')
|
||||
return Prompt([Fragment('')])
|
||||
#print(f' b parsing \'{x_unquoted}\'')
|
||||
x_parsed = prompt.parse_string(x_unquoted)
|
||||
#print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed)
|
||||
return x_parsed[0]
|
||||
|
||||
quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
|
||||
quoted_prompt.set_name('quoted_prompt')
|
||||
|
||||
debug_blend=False
|
||||
blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend)
|
||||
blend_weights = (pp.delimited_list(number) + pp.Optional(pp.Char(",").suppress() + "no_normalize")).set_name('blend_weights').set_debug(debug_blend)
|
||||
blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
|
||||
+ pp.Literal(".blend").suppress()
|
||||
+ lparen + pp.Group(blend_weights) + rparen).set_name('blend')
|
||||
blend.set_debug(debug_blend)
|
||||
|
||||
def make_blend(x):
|
||||
prompts = x[0][0]
|
||||
weights = x[0][1]
|
||||
normalize = True
|
||||
if weights[-1] == 'no_normalize':
|
||||
normalize = False
|
||||
weights = weights[:-1]
|
||||
return Blend(prompts=prompts, weights=weights, normalize_weights=normalize)
|
||||
|
||||
blend.set_parse_action(make_blend)
|
||||
|
||||
conjunction_terms = blend_terms.copy().set_name('conjunction_terms')
|
||||
conjunction_weights = blend_weights.copy().set_name('conjunction_weights')
|
||||
conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen
|
||||
+ pp.Literal(".and").suppress()
|
||||
+ lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction')
|
||||
def make_conjunction(x):
|
||||
parts_raw = x[0][0]
|
||||
weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw)
|
||||
parts = [part for part in parts_raw]
|
||||
return Conjunction(parts, weights)
|
||||
conjunction_with_parens_and_quotes.set_parse_action(make_conjunction)
|
||||
|
||||
implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction')
|
||||
# by default a prompt consists of a Conjunction with a single term
|
||||
implicit_conjunction = (blend | pp.Group(prompt)) + pp.StringEnd()
|
||||
implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
|
||||
|
||||
conjunction = conjunction_with_parens_and_quotes | implicit_conjunction
|
||||
conjunction.set_debug(False)
|
||||
conjunction = (explicit_conjunction | implicit_conjunction)
|
||||
|
||||
# top-level is a conjunction of one or more blends or prompts
|
||||
return conjunction, prompt
|
||||
|
||||
|
||||
|
||||
def split_weighted_subprompts(text, skip_normalize=False)->list:
|
||||
"""
|
||||
Legacy blend parsing.
|
||||
|
@ -28,8 +28,8 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt(''))
|
||||
|
||||
def test_basic(self):
|
||||
self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)"))
|
||||
self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames"))
|
||||
self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)"))
|
||||
self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames"))
|
||||
self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire"))
|
||||
self.assertEqual(make_weighted_conjunction([("cat hot-dog eating", 1)]), parse_prompt("cat hot-dog eating"))
|
||||
@ -37,14 +37,25 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
|
||||
def test_attention(self):
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames).attend(0.5)"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("flames.attend(0.5)"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("\"flames\".attend(0.5)"))
|
||||
self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5"))
|
||||
self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames).attend(0.5)"))
|
||||
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames.attend(+)"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames).attend(+)"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\".attend(+)"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("\"flames\"-"))
|
||||
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5"))
|
||||
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire flames.attend(0.5)"))
|
||||
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames).attend(0.5)"))
|
||||
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire \"flames\".attend(0.5)"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--"))
|
||||
self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++"))
|
||||
@ -102,20 +113,17 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
|
||||
assert_if_prompt_string_not_untouched('a test prompt')
|
||||
assert_if_prompt_string_not_untouched('a badly formed +test prompt')
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('a badly (formed test prompt')
|
||||
assert_if_prompt_string_not_untouched('a badly (formed test prompt')
|
||||
|
||||
#with self.assertRaises(pyparsing.ParseException):
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('a badly (formed +test prompt')
|
||||
assert_if_prompt_string_not_untouched('a badly (formed +test prompt')
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a badly formed +test prompt',1)])]) , parse_prompt('a badly (formed +test )prompt'))
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('(((a badly (formed +test )prompt')
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('(a (ba)dly (f)ormed +test prompt')
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('(a (ba)dly (f)ormed +test +prompt')
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('("((a badly (formed +test ").blend(1.0)')
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(((a badly formed +test prompt',1)])]) , parse_prompt('(((a badly (formed +test )prompt'))
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(a ba dly f ormed +test prompt',1)])]) , parse_prompt('(a (ba)dly (f)ormed +test prompt'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(a ba dly f ormed +test +prompt',1)])]) , parse_prompt('(a (ba)dly (f)ormed +test +prompt'))
|
||||
self.assertEqual(Conjunction([Blend([FlattenedPrompt([Fragment('((a badly (formed +test', 1)])], [1.0])]),
|
||||
parse_prompt('("((a badly (formed +test ").blend(1.0)'))
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
|
||||
parse_prompt("hamburger ((bun))"))
|
||||
@ -128,6 +136,26 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
|
||||
|
||||
def test_blend(self):
|
||||
self.assertEqual(Conjunction(
|
||||
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
|
||||
parse_prompt("(\"mountain\", \"man\").blend()")
|
||||
)
|
||||
self.assertEqual(Conjunction(
|
||||
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
|
||||
parse_prompt("(mountain, man).blend()")
|
||||
)
|
||||
self.assertEqual(Conjunction(
|
||||
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
|
||||
parse_prompt("((mountain), (man)).blend()")
|
||||
)
|
||||
self.assertEqual(Conjunction(
|
||||
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('tall man', 1.0)])], [1.0, 1.0])]),
|
||||
parse_prompt("((mountain), (tall man)).blend()")
|
||||
)
|
||||
|
||||
with self.assertRaises(PromptParser.ParsingException):
|
||||
print(parse_prompt("((mountain), \"cat.swap(dog)\").blend()"))
|
||||
|
||||
self.assertEqual(Conjunction(
|
||||
[Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]),
|
||||
parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)")
|
||||
@ -167,9 +195,19 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
|
||||
self.assertEqual(
|
||||
Conjunction([Blend([FlattenedPrompt([('mountain , man , hairy', 1)]),
|
||||
FlattenedPrompt([('face, teeth,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0])]),
|
||||
FlattenedPrompt([('face , teeth ,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0], normalize_weights=True)]),
|
||||
parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)')
|
||||
)
|
||||
self.assertEqual(
|
||||
Conjunction([Blend([FlattenedPrompt([('mountain , man , hairy', 1)]),
|
||||
FlattenedPrompt([('face , teeth ,', 1), ('eyes', 0.9 * 0.9)])], weights=[1.0, -1.0], normalize_weights=False)]),
|
||||
parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1,no_normalize)')
|
||||
)
|
||||
|
||||
with self.assertRaises(PromptParser.ParsingException):
|
||||
parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3, 0.1)")
|
||||
with self.assertRaises(PromptParser.ParsingException):
|
||||
parse_prompt("(\"fire\", \"fire flames\").blend(0.7)")
|
||||
|
||||
|
||||
def test_nested(self):
|
||||
@ -182,6 +220,9 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
|
||||
def test_cross_attention_control(self):
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([CrossAttentionControlSubstitute([Fragment('sun')], [Fragment('moon')])])]),
|
||||
parse_prompt("sun.swap(moon)"))
|
||||
|
||||
self.assertEqual(Conjunction([
|
||||
FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
|
||||
@ -259,6 +300,12 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0'))
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
|
||||
Fragment('eating a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))])
|
||||
])]),
|
||||
parse_prompt("a cat.swap(dog) eating a hotdog.swap(hotdog++++)"))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
|
||||
Fragment('eating a', 1),
|
||||
@ -433,6 +480,15 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
|
||||
|
||||
def test_single(self):
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]),
|
||||
FlattenedPrompt([("a person with a hat", 1.0),
|
||||
("riding a", 1.1*1.1),
|
||||
CrossAttentionControlSubstitute(
|
||||
[Fragment("bicycle", pow(1.1,2))],
|
||||
[Fragment("skateboard", pow(1.1,2))])
|
||||
])
|
||||
], weights=[0.5, 0.5]),
|
||||
parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)"))
|
||||
pass
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user