mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Rebuilt prompt parsing logic
Complete re-write of the prompt parsing logic to be more readable and logical, and therefore also hopefully easier to debug, maintain, and augment. In the process it has also become more robust to badly-formed prompts. Squashed commit of the following: commit 8fcfa88a16e1390d41717e940d72aed64712171c Author: Damian at mba <damian@frey.NOSPAMco.nz> Date: Sun Oct 30 17:05:57 2022 +0100 further cleanup commit 1a1fd78bcfeb49d072e3e6d5808aa8df15441629 Author: Damian at mba <damian@frey.NOSPAMco.nz> Date: Sun Oct 30 16:07:57 2022 +0100 cleanup and document commit 099c9659fa8b8135876f9a5a50fe80b20bc0635c Author: Damian at mba <damian@frey.NOSPAMco.nz> Date: Sun Oct 30 15:54:58 2022 +0100 works fully commit 5e6887ea8c25a1e21438ff6defb381fd027d25fd Author: Damian at mba <damian@frey.NOSPAMco.nz> Date: Sun Oct 30 15:24:31 2022 +0100 further... commit 492fda120844d9bc1ad4ec7dd408a3374762d0ff Author: Damian at mba <damian@frey.NOSPAMco.nz> Date: Sun Oct 30 14:08:57 2022 +0100 getting there... commit c6aab05a8450cc3c95c8691daf38fdc64c74f52d Author: Damian at mba <damian@frey.NOSPAMco.nz> Date: Fri Oct 28 14:29:03 2022 +0200 wip doesn't compile commit 5e533f731cfd20cd435330eeb0012e5689e87e81 Author: Damian at mba <damian@frey.NOSPAMco.nz> Date: Fri Oct 28 13:21:43 2022 +0200 working with CrossAttentionCtonrol but no Attention support yet commit 9678348773431e500e110e8aede99086bb7b5955 Author: Damian at mba <damian@frey.NOSPAMco.nz> Date: Fri Oct 28 13:04:52 2022 +0200 wip rebuiling prompt parser
This commit is contained in:
parent
349cc25433
commit
e554c2607f
@ -28,7 +28,7 @@ class Prompt():
|
|||||||
def __init__(self, parts: list):
|
def __init__(self, parts: list):
|
||||||
for c in parts:
|
for c in parts:
|
||||||
if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
|
if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
|
||||||
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} {c}, only {BaseFragment.__subclasses__()} are allowed")
|
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} ({c}), only {[c.__name__ for c in BaseFragment.__subclasses__()]} are allowed")
|
||||||
self.children = parts
|
self.children = parts
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"Prompt:{self.children}"
|
return f"Prompt:{self.children}"
|
||||||
@ -102,12 +102,18 @@ class Attention():
|
|||||||
Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object.
|
Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object.
|
||||||
"""
|
"""
|
||||||
def __init__(self, weight: float, children: list):
|
def __init__(self, weight: float, children: list):
|
||||||
|
if type(weight) is not float:
|
||||||
|
raise PromptParser.ParsingException(
|
||||||
|
f"Attention weight must be float (got {type(weight).__name__} {weight})")
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
|
if type(children) is not list:
|
||||||
|
raise PromptParser.ParsingException(f"cannot make Attention with non-list of children (got {type(children)})")
|
||||||
|
assert(type(children) is list)
|
||||||
self.children = children
|
self.children = children
|
||||||
#print(f"A: requested attention '{children}' to {weight}")
|
#print(f"A: requested attention '{children}' to {weight}")
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"Attention:'{self.children}' @ {self.weight}"
|
return f"Attention:{self.children} * {self.weight}"
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment
|
return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment
|
||||||
|
|
||||||
@ -136,9 +142,9 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
|
|||||||
Fragment('sitting on a car')
|
Fragment('sitting on a car')
|
||||||
])
|
])
|
||||||
"""
|
"""
|
||||||
def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list], options: dict=None):
|
def __init__(self, original: list, edited: list, options: dict=None):
|
||||||
self.original = original
|
self.original = original
|
||||||
self.edited = edited
|
self.edited = edited if len(edited)>0 else [Fragment('')]
|
||||||
|
|
||||||
default_options = {
|
default_options = {
|
||||||
's_start': 0.0,
|
's_start': 0.0,
|
||||||
@ -190,12 +196,12 @@ class Conjunction():
|
|||||||
"""
|
"""
|
||||||
def __init__(self, prompts: list, weights: list = None):
|
def __init__(self, prompts: list, weights: list = None):
|
||||||
# force everything to be a Prompt
|
# force everything to be a Prompt
|
||||||
#print("making conjunction with", parts)
|
#print("making conjunction with", prompts, "types", [type(p).__name__ for p in prompts])
|
||||||
self.prompts = [x if (type(x) is Prompt
|
self.prompts = [x if (type(x) is Prompt
|
||||||
or type(x) is Blend
|
or type(x) is Blend
|
||||||
or type(x) is FlattenedPrompt)
|
or type(x) is FlattenedPrompt)
|
||||||
else Prompt(x) for x in prompts]
|
else Prompt(x) for x in prompts]
|
||||||
self.weights = [1.0]*len(self.prompts) if weights is None else list(weights)
|
self.weights = [1.0]*len(self.prompts) if (weights is None or len(weights)==0) else list(weights)
|
||||||
if len(self.weights) != len(self.prompts):
|
if len(self.weights) != len(self.prompts):
|
||||||
raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
|
raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
|
||||||
self.type = 'AND'
|
self.type = 'AND'
|
||||||
@ -216,6 +222,7 @@ class Blend():
|
|||||||
"""
|
"""
|
||||||
def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
|
def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
|
||||||
#print("making Blend with prompts", prompts, "and weights", weights)
|
#print("making Blend with prompts", prompts, "and weights", weights)
|
||||||
|
weights = [1.0]*len(prompts) if (weights is None or len(weights)==0) else list(weights)
|
||||||
if len(prompts) != len(weights):
|
if len(prompts) != len(weights):
|
||||||
raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
|
raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
|
||||||
for p in prompts:
|
for p in prompts:
|
||||||
@ -244,6 +251,10 @@ class PromptParser():
|
|||||||
class ParsingException(Exception):
|
class ParsingException(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class UnrecognizedOperatorException(Exception):
|
||||||
|
def __init__(self, operator:str):
|
||||||
|
super().__init__("Unrecognized operator: " + operator)
|
||||||
|
|
||||||
def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
|
def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
|
||||||
|
|
||||||
self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base)
|
self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base)
|
||||||
@ -279,7 +290,7 @@ class PromptParser():
|
|||||||
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
|
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
|
||||||
|
|
||||||
|
|
||||||
def flatten(self, root: Conjunction) -> Conjunction:
|
def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
|
||||||
"""
|
"""
|
||||||
Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
|
Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
|
||||||
producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
|
producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
|
||||||
@ -289,8 +300,6 @@ class PromptParser():
|
|||||||
:return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
|
:return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
#print("flattening", root)
|
|
||||||
|
|
||||||
def fuse_fragments(items):
|
def fuse_fragments(items):
|
||||||
# print("fusing fragments in ", items)
|
# print("fusing fragments in ", items)
|
||||||
result = []
|
result = []
|
||||||
@ -313,8 +322,8 @@ class PromptParser():
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def flatten_internal(node, weight_scale, results, prefix):
|
def flatten_internal(node, weight_scale, results, prefix):
|
||||||
#print(prefix + "flattening", node, "...")
|
verbose and print(prefix + "flattening", node, "...")
|
||||||
if type(node) is pp.ParseResults:
|
if type(node) is pp.ParseResults or type(node) is list:
|
||||||
for x in node:
|
for x in node:
|
||||||
results = flatten_internal(x, weight_scale, results, prefix+' pr ')
|
results = flatten_internal(x, weight_scale, results, prefix+' pr ')
|
||||||
#print(prefix, " ParseResults expanded, results is now", results)
|
#print(prefix, " ParseResults expanded, results is now", results)
|
||||||
@ -345,67 +354,59 @@ class PromptParser():
|
|||||||
#print(prefix + "after flattening Prompt, results is", results)
|
#print(prefix + "after flattening Prompt, results is", results)
|
||||||
else:
|
else:
|
||||||
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
|
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
|
||||||
#print(prefix + "-> after flattening", type(node).__name__, "results is", results)
|
verbose and print(prefix + "-> after flattening", type(node).__name__, "results is", results)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
verbose and print("flattening", root)
|
||||||
|
|
||||||
flattened_parts = []
|
flattened_parts = []
|
||||||
for part in root.prompts:
|
for part in root.prompts:
|
||||||
flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
|
flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
|
||||||
|
|
||||||
#print("flattened to", flattened_parts)
|
verbose and print("flattened to", flattened_parts)
|
||||||
|
|
||||||
weights = root.weights
|
weights = root.weights
|
||||||
return Conjunction(flattened_parts, weights)
|
return Conjunction(flattened_parts, weights)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
|
def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
|
||||||
|
def make_operator_object(x):
|
||||||
|
#print('making operator for', x)
|
||||||
|
target = x[0]
|
||||||
|
operator = x[1]
|
||||||
|
arguments = x[2]
|
||||||
|
if operator == '.attend':
|
||||||
|
weight_raw = arguments[0]
|
||||||
|
weight = 1.0
|
||||||
|
if type(weight_raw) is float or type(weight_raw) is int:
|
||||||
|
weight = weight_raw
|
||||||
|
elif type(weight_raw) is str:
|
||||||
|
base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
|
||||||
|
weight = pow(base, len(weight_raw))
|
||||||
|
return Attention(weight=weight, children=[x for x in x[0]])
|
||||||
|
elif operator == '.swap':
|
||||||
|
return CrossAttentionControlSubstitute(target, arguments, x.as_dict())
|
||||||
|
elif operator == '.blend':
|
||||||
|
prompts = [Prompt(p) for p in x[0]]
|
||||||
|
weights_raw = x[2]
|
||||||
|
normalize_weights = True
|
||||||
|
if len(weights_raw) > 0 and weights_raw[-1][0] == 'no_normalize':
|
||||||
|
normalize_weights = False
|
||||||
|
weights_raw = weights_raw[:-1]
|
||||||
|
weights = [float(w[0]) for w in weights_raw]
|
||||||
|
return Blend(prompts=prompts, weights=weights, normalize_weights=normalize_weights)
|
||||||
|
elif operator == '.and' or operator == '.add':
|
||||||
|
prompts = [Prompt(p) for p in x[0]]
|
||||||
|
weights = [float(w[0]) for w in x[2]]
|
||||||
|
return Conjunction(prompts=prompts, weights=weights)
|
||||||
|
|
||||||
lparen = pp.Literal("(").suppress()
|
raise PromptParser.UnrecognizedOperatorException(operator)
|
||||||
rparen = pp.Literal(")").suppress()
|
|
||||||
quotes = pp.Literal('"').suppress()
|
|
||||||
comma = pp.Literal(",").suppress()
|
|
||||||
|
|
||||||
# accepts int or float notation, always maps to float
|
def parse_fragment_str(x, expression: pp.ParseExpression, in_quotes: bool = False, in_parens: bool = False):
|
||||||
number = pp.pyparsing_common.real | \
|
|
||||||
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
|
|
||||||
|
|
||||||
attention = pp.Forward()
|
|
||||||
quoted_fragment = pp.Forward()
|
|
||||||
parenthesized_fragment = pp.Forward()
|
|
||||||
cross_attention_substitute = pp.Forward()
|
|
||||||
|
|
||||||
def make_text_fragment(x):
|
|
||||||
#print("### making fragment for", x)
|
|
||||||
if type(x[0]) is Fragment:
|
|
||||||
assert(False)
|
|
||||||
if type(x) is str:
|
|
||||||
return Fragment(x)
|
|
||||||
elif type(x) is pp.ParseResults or type(x) is list:
|
|
||||||
#print(f'converting {type(x).__name__} to Fragment')
|
|
||||||
return Fragment(' '.join([s for s in x]))
|
|
||||||
else:
|
|
||||||
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
|
|
||||||
|
|
||||||
def build_escaped_word_parser_charbychar(escaped_chars_to_ignore: str):
|
|
||||||
escapes = []
|
|
||||||
for c in escaped_chars_to_ignore:
|
|
||||||
escapes.append(pp.Literal('\\'+c))
|
|
||||||
return pp.Combine(pp.OneOrMore(
|
|
||||||
pp.MatchFirst(escapes + [pp.CharsNotIn(
|
|
||||||
string.whitespace + escaped_chars_to_ignore,
|
|
||||||
exact=1
|
|
||||||
)])
|
|
||||||
))
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
|
|
||||||
#print(f"parsing fragment string for {x}")
|
#print(f"parsing fragment string for {x}")
|
||||||
fragment_string = x[0]
|
fragment_string = x[0]
|
||||||
#print(f"ppparsing fragment string \"{fragment_string}\"")
|
|
||||||
|
|
||||||
if len(fragment_string.strip()) == 0:
|
if len(fragment_string.strip()) == 0:
|
||||||
return Fragment('')
|
return Fragment('')
|
||||||
|
|
||||||
@ -413,234 +414,198 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
# escape unescaped quotes
|
# escape unescaped quotes
|
||||||
fragment_string = fragment_string.replace('"', '\\"')
|
fragment_string = fragment_string.replace('"', '\\"')
|
||||||
|
|
||||||
#fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment))))
|
|
||||||
try:
|
try:
|
||||||
result = pp.Group(pp.MatchFirst([
|
result = (expression + pp.StringEnd()).parse_string(fragment_string)
|
||||||
pp.OneOrMore(quoted_fragment | attention | unquoted_word).set_name('pf_str_qfuq'),
|
|
||||||
pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd()
|
|
||||||
])).set_name('blend-result').set_debug(False).parse_string(fragment_string)
|
|
||||||
#print("parsed to", result)
|
#print("parsed to", result)
|
||||||
return result
|
return result
|
||||||
except pp.ParseException as e:
|
except pp.ParseException as e:
|
||||||
#print("parse_fragment_str couldn't parse prompt string:", e)
|
#print("parse_fragment_str couldn't parse prompt string:", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
# meaningful symbols
|
||||||
|
lparen = pp.Literal("(").suppress()
|
||||||
|
rparen = pp.Literal(")").suppress()
|
||||||
|
quote = pp.Literal('"').suppress()
|
||||||
|
comma = pp.Literal(",").suppress()
|
||||||
|
dot = pp.Literal(".").suppress()
|
||||||
|
equals = pp.Literal("=").suppress()
|
||||||
|
|
||||||
|
escaped_lparen = pp.Literal('\\(')
|
||||||
|
escaped_rparen = pp.Literal('\\)')
|
||||||
|
escaped_quote = pp.Literal('\\"')
|
||||||
|
escaped_comma = pp.Literal('\\,')
|
||||||
|
escaped_dot = pp.Literal('\\.')
|
||||||
|
escaped_plus = pp.Literal('\\+')
|
||||||
|
escaped_minus = pp.Literal('\\-')
|
||||||
|
escaped_equals = pp.Literal('\\=')
|
||||||
|
|
||||||
|
syntactic_symbols = {
|
||||||
|
'(': escaped_lparen,
|
||||||
|
')': escaped_rparen,
|
||||||
|
'"': escaped_quote,
|
||||||
|
',': escaped_comma,
|
||||||
|
'.': escaped_dot,
|
||||||
|
'+': escaped_plus,
|
||||||
|
'-': escaped_minus,
|
||||||
|
'=': escaped_equals,
|
||||||
|
}
|
||||||
|
syntactic_chars = "".join(syntactic_symbols.keys())
|
||||||
|
|
||||||
|
# accepts int or float notation, always maps to float
|
||||||
|
number = pp.pyparsing_common.real | \
|
||||||
|
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
|
||||||
|
|
||||||
|
# for options
|
||||||
|
keyword = pp.Word(pp.alphanums + '_')
|
||||||
|
|
||||||
|
# a word that absolutely does not contain any meaningful syntax
|
||||||
|
non_syntax_word = pp.Combine(pp.OneOrMore(pp.MatchFirst([
|
||||||
|
pp.Or(syntactic_symbols.values()),
|
||||||
|
pp.one_of(['-', '+']) + pp.NotAny(pp.White() | pp.Char(syntactic_chars) | pp.StringEnd()),
|
||||||
|
# build character-by-character
|
||||||
|
pp.CharsNotIn(string.whitespace + syntactic_chars, exact=1)
|
||||||
|
])))
|
||||||
|
non_syntax_word.set_parse_action(lambda x: [Fragment(t) for t in x])
|
||||||
|
non_syntax_word.set_name('non_syntax_word')
|
||||||
|
non_syntax_word.set_debug(False)
|
||||||
|
|
||||||
|
# a word that can contain any character at all - greedily consumes syntax, so use with care
|
||||||
|
free_word = pp.CharsNotIn(string.whitespace).set_parse_action(lambda x: Fragment(x[0]))
|
||||||
|
free_word.set_name('free_word')
|
||||||
|
free_word.set_debug(False)
|
||||||
|
|
||||||
|
|
||||||
|
# ok here we go. forward declare some things..
|
||||||
|
attention = pp.Forward()
|
||||||
|
cross_attention_substitute = pp.Forward()
|
||||||
|
parenthesized_fragment = pp.Forward()
|
||||||
|
quoted_fragment = pp.Forward()
|
||||||
|
|
||||||
|
# the types of things that can go into a fragment, consisting of syntax-full and/or strictly syntax-free components
|
||||||
|
fragment_part_expressions = [
|
||||||
|
attention,
|
||||||
|
cross_attention_substitute,
|
||||||
|
parenthesized_fragment,
|
||||||
|
quoted_fragment,
|
||||||
|
non_syntax_word
|
||||||
|
]
|
||||||
|
# a fragment that is permitted to contain commas
|
||||||
|
fragment_including_commas = pp.ZeroOrMore(pp.MatchFirst(
|
||||||
|
fragment_part_expressions + [
|
||||||
|
pp.Literal(',').set_parse_action(lambda x: Fragment(x[0]))
|
||||||
|
]
|
||||||
|
))
|
||||||
|
# a fragment that is not permitted to contain commas
|
||||||
|
fragment_excluding_commas = pp.ZeroOrMore(pp.MatchFirst(
|
||||||
|
fragment_part_expressions
|
||||||
|
))
|
||||||
|
|
||||||
|
# a fragment in double quotes (may be nested)
|
||||||
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
|
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
|
||||||
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment')
|
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, fragment_including_commas, in_quotes=True))
|
||||||
|
|
||||||
escaped_quote = pp.Literal('\\"')#.set_parse_action(lambda x: '"')
|
# a fragment inside parentheses (may be nested)
|
||||||
escaped_lparen = pp.Literal('\\(')#.set_parse_action(lambda x: '(')
|
parenthesized_fragment << (lparen + fragment_including_commas + rparen)
|
||||||
escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')')
|
parenthesized_fragment.set_name('parenthesized_fragment')
|
||||||
escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"')
|
parenthesized_fragment.set_debug(False)
|
||||||
|
|
||||||
empty = (
|
# a string of the form (<keyword>=<float|keyword> | <float> | <keyword>) where keyword is alphanumeric + '_'
|
||||||
(lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) |
|
option = pp.Group(pp.MatchFirst([
|
||||||
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')
|
keyword + equals + (number | keyword), # option=value
|
||||||
|
number.copy().set_parse_action(pp.token_map(str)), # weight
|
||||||
|
keyword # flag
|
||||||
def not_ends_with_swap(x):
|
|
||||||
#print("trying to match:", x)
|
|
||||||
return not x[0].endswith('.swap')
|
|
||||||
|
|
||||||
unquoted_word = (pp.Combine(pp.OneOrMore(
|
|
||||||
escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash |
|
|
||||||
(pp.CharsNotIn(string.whitespace + '\\"()', exact=1)
|
|
||||||
)))
|
|
||||||
# don't whitespace when the next word starts with +, eg "badly +formed"
|
|
||||||
+ (pp.White().suppress() |
|
|
||||||
# don't eat +/-
|
|
||||||
pp.NotAny(pp.Word('+') | pp.Word('-'))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
unquoted_word.set_parse_action(make_text_fragment).set_name('unquoted_word').set_debug(False)
|
|
||||||
#print(unquoted_fragment.parse_string("cat.swap(dog)"))
|
|
||||||
|
|
||||||
parenthesized_fragment << (lparen +
|
|
||||||
pp.Or([
|
|
||||||
(parenthesized_fragment),
|
|
||||||
(quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False)).set_name('-quoted_paren_internal').set_debug(False),
|
|
||||||
(pp.Combine(pp.OneOrMore(
|
|
||||||
escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash |
|
|
||||||
pp.CharsNotIn(string.whitespace + '\\"()', exact=1) |
|
|
||||||
pp.White()
|
|
||||||
)).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False)),
|
|
||||||
pp.Empty()
|
|
||||||
]) + rparen)
|
|
||||||
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False)
|
|
||||||
|
|
||||||
debug_attention = False
|
|
||||||
# attention control of the form (phrase)+ / (phrase)+ / (phrase)<weight>
|
|
||||||
# phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
|
|
||||||
attention_with_parens = pp.Forward()
|
|
||||||
attention_without_parens = pp.Forward()
|
|
||||||
|
|
||||||
attention_with_parens_foot = (number | pp.Word('+') | pp.Word('-'))\
|
|
||||||
.set_name("attention_foot")\
|
|
||||||
.set_debug(False)
|
|
||||||
attention_with_parens <<= pp.Group(
|
|
||||||
lparen +
|
|
||||||
pp.ZeroOrMore(quoted_fragment | attention_with_parens | parenthesized_fragment | cross_attention_substitute | attention_without_parens |
|
|
||||||
(pp.Empty() + build_escaped_word_parser_charbychar('()')).set_name('undecorated_word').set_debug(debug_attention)#.set_parse_action(lambda t: t[0])
|
|
||||||
)
|
|
||||||
+ rparen + attention_with_parens_foot)
|
|
||||||
attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention)
|
|
||||||
|
|
||||||
attention_without_parens_foot = (pp.NotAny(pp.White()) + pp.Or([pp.Word('+'), pp.Word('-')]) + pp.FollowedBy(pp.StringEnd() | pp.White() | pp.Literal('(') | pp.Literal(')') | pp.Literal(',') | pp.Literal('"')) ).set_name('attention_without_parens_foots')
|
|
||||||
attention_without_parens <<= pp.Group(pp.MatchFirst([
|
|
||||||
quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot,
|
|
||||||
pp.Combine(build_escaped_word_parser_charbychar('()+-')).set_name('attention_word_without_parens').set_debug(debug_attention)#.set_parse_action(lambda x: print('escapéd', x))
|
|
||||||
+ attention_without_parens_foot#.leave_whitespace()
|
|
||||||
]))
|
]))
|
||||||
attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention)
|
# options for an operator, eg "s_start=0.1, 0.3, no_normalize"
|
||||||
|
options = pp.Dict(pp.Optional(pp.delimited_list(option)))
|
||||||
|
options.set_name('options')
|
||||||
|
options.set_debug(False)
|
||||||
|
|
||||||
|
# a fragment which can be used as the target for an operator - either quoted or in parentheses, or a bare vanilla word
|
||||||
|
potential_operator_target = (quoted_fragment | parenthesized_fragment | non_syntax_word)
|
||||||
|
|
||||||
attention << pp.MatchFirst([attention_with_parens,
|
# a fragment whose weight has been increased or decreased by a given amount
|
||||||
attention_without_parens
|
attention_weight_operator = pp.Word('+') | pp.Word('-') | number
|
||||||
])
|
attention_explicit = (
|
||||||
|
pp.Group(potential_operator_target)
|
||||||
|
+ pp.Literal('.attend')
|
||||||
|
+ lparen
|
||||||
|
+ pp.Group(attention_weight_operator)
|
||||||
|
+ rparen
|
||||||
|
)
|
||||||
|
attention_explicit.set_parse_action(make_operator_object)
|
||||||
|
attention_implicit = (
|
||||||
|
pp.Group(potential_operator_target)
|
||||||
|
+ pp.NotAny(pp.White()) # do not permit whitespace between term and operator
|
||||||
|
+ pp.Group(attention_weight_operator)
|
||||||
|
)
|
||||||
|
attention_implicit.set_parse_action(lambda x: make_operator_object([x[0], '.attend', x[1]]))
|
||||||
|
attention << (attention_explicit | attention_implicit)
|
||||||
attention.set_name('attention')
|
attention.set_name('attention')
|
||||||
|
attention.set_debug(False)
|
||||||
|
|
||||||
def make_attention(x):
|
# cross-attention control by swapping one fragment for another
|
||||||
#print("entered make_attention with", x)
|
cross_attention_substitute << (
|
||||||
children = x[0][:-1]
|
pp.Group(potential_operator_target).set_name('ca-target').set_debug(False)
|
||||||
weight_raw = x[0][-1]
|
+ pp.Literal(".swap").set_name('ca-operator').set_debug(False)
|
||||||
weight = 1.0
|
+ lparen
|
||||||
if type(weight_raw) is float or type(weight_raw) is int:
|
+ pp.Group(fragment_excluding_commas).set_name('ca-replacement').set_debug(False)
|
||||||
weight = weight_raw
|
+ pp.Optional(comma + options).set_name('ca-options').set_debug(False)
|
||||||
elif type(weight_raw) is str:
|
+ rparen
|
||||||
base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
|
)
|
||||||
weight = pow(base, len(weight_raw))
|
cross_attention_substitute.set_name('cross_attention_substitute')
|
||||||
|
cross_attention_substitute.set_debug(False)
|
||||||
#print("making Attention from", children, "with weight", weight)
|
cross_attention_substitute.set_parse_action(make_operator_object)
|
||||||
|
|
||||||
return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in children])
|
|
||||||
|
|
||||||
attention_with_parens.set_parse_action(make_attention)
|
|
||||||
attention_without_parens.set_parse_action(make_attention)
|
|
||||||
|
|
||||||
#print("parsing test:", attention_with_parens.parse_string("mountain (man)1.1"))
|
|
||||||
|
|
||||||
# cross-attention control
|
|
||||||
empty_string = ((lparen + rparen) |
|
|
||||||
pp.Literal('""').suppress() |
|
|
||||||
(lparen + pp.Literal('""').suppress() + rparen)
|
|
||||||
).set_parse_action(lambda x: Fragment(""))
|
|
||||||
empty_string.set_name('empty_string')
|
|
||||||
|
|
||||||
# cross attention control
|
|
||||||
debug_cross_attention_control = False
|
|
||||||
original_fragment = pp.MatchFirst([
|
|
||||||
quoted_fragment.set_debug(debug_cross_attention_control),
|
|
||||||
parenthesized_fragment.set_debug(debug_cross_attention_control),
|
|
||||||
pp.Combine(pp.OneOrMore(pp.CharsNotIn(string.whitespace + '.', exact=1))).set_parse_action(make_text_fragment) + pp.FollowedBy(".swap"),
|
|
||||||
empty_string.set_debug(debug_cross_attention_control),
|
|
||||||
])
|
|
||||||
# support keyword=number arguments
|
|
||||||
cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")])
|
|
||||||
cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number)
|
|
||||||
edited_fragment = pp.MatchFirst([
|
|
||||||
(lparen + rparen).set_parse_action(lambda x: Fragment('')),
|
|
||||||
lparen +
|
|
||||||
(quoted_fragment | attention |
|
|
||||||
pp.Group(pp.ZeroOrMore(build_escaped_word_parser_charbychar(',)').set_parse_action(make_text_fragment)))
|
|
||||||
) +
|
|
||||||
pp.Dict(pp.ZeroOrMore(comma + cross_attention_option)) +
|
|
||||||
rparen,
|
|
||||||
parenthesized_fragment
|
|
||||||
])
|
|
||||||
cross_attention_substitute << original_fragment + pp.Literal(".swap").set_debug(False).suppress() + edited_fragment
|
|
||||||
|
|
||||||
original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
|
|
||||||
edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control)
|
|
||||||
cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control)
|
|
||||||
|
|
||||||
def make_cross_attention_substitute(x):
|
|
||||||
#print("making cacs for", x[0], "->", x[1], "with options", x.as_dict())
|
|
||||||
#if len(x>2):
|
|
||||||
cacs = CrossAttentionControlSubstitute(x[0], x[1], options=x.as_dict())
|
|
||||||
#print("made", cacs)
|
|
||||||
return cacs
|
|
||||||
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
|
|
||||||
|
|
||||||
|
|
||||||
# root prompt definition
|
# an entire self-contained prompt, which can be used in a Blend or Conjunction
|
||||||
debug_root_prompt = False
|
prompt = pp.ZeroOrMore(pp.MatchFirst([
|
||||||
prompt = (pp.OneOrMore(pp.MatchFirst([cross_attention_substitute.set_debug(debug_root_prompt),
|
cross_attention_substitute,
|
||||||
attention.set_debug(debug_root_prompt),
|
attention,
|
||||||
quoted_fragment.set_debug(debug_root_prompt),
|
quoted_fragment,
|
||||||
parenthesized_fragment.set_debug(debug_root_prompt),
|
parenthesized_fragment,
|
||||||
unquoted_word.set_debug(debug_root_prompt),
|
free_word,
|
||||||
empty.set_parse_action(make_text_fragment).set_debug(debug_root_prompt)])
|
pp.White().suppress()
|
||||||
) + pp.StringEnd()) \
|
]))
|
||||||
.set_name('prompt') \
|
quoted_prompt = quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, prompt, in_quotes=True))
|
||||||
.set_parse_action(lambda x: Prompt(x)) \
|
|
||||||
.set_debug(debug_root_prompt)
|
|
||||||
|
|
||||||
#print("parsing test:", prompt.parse_string("spaced eyes--"))
|
|
||||||
#print("parsing test:", prompt.parse_string("eyes--"))
|
|
||||||
|
|
||||||
# weighted blend of prompts
|
# a blend/lerp between the feature vectors for two or more prompts
|
||||||
# ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or
|
blend = (
|
||||||
# int weights.
|
lparen
|
||||||
# can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c)
|
+ pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('bl-target').set_debug(False)
|
||||||
|
+ rparen
|
||||||
|
+ pp.Literal(".blend").set_name('bl-operator').set_debug(False)
|
||||||
|
+ lparen
|
||||||
|
+ pp.Group(options).set_name('bl-options').set_debug(False)
|
||||||
|
+ rparen
|
||||||
|
)
|
||||||
|
blend.set_name('blend')
|
||||||
|
blend.set_debug(False)
|
||||||
|
blend.set_parse_action(make_operator_object)
|
||||||
|
|
||||||
def make_prompt_from_quoted_string(x):
|
# an operator to direct stable diffusion to step multiple times, once for each target, and then add the results together with different weights
|
||||||
#print(' got quoted prompt', x)
|
explicit_conjunction = (
|
||||||
|
lparen
|
||||||
|
+ pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('cj-target').set_debug(False)
|
||||||
|
+ rparen
|
||||||
|
+ pp.one_of([".and", ".add"]).set_name('cj-operator').set_debug(False)
|
||||||
|
+ lparen
|
||||||
|
+ pp.Group(options).set_name('cj-options').set_debug(False)
|
||||||
|
+ rparen
|
||||||
|
)
|
||||||
|
explicit_conjunction.set_name('explicit_conjunction')
|
||||||
|
explicit_conjunction.set_debug(False)
|
||||||
|
explicit_conjunction.set_parse_action(make_operator_object)
|
||||||
|
|
||||||
x_unquoted = x[0][1:-1]
|
# by default a prompt consists of a Conjunction with a single term
|
||||||
if len(x_unquoted.strip()) == 0:
|
implicit_conjunction = (blend | pp.Group(prompt)) + pp.StringEnd()
|
||||||
# print(' b : just an empty string')
|
|
||||||
return Prompt([Fragment('')])
|
|
||||||
#print(f' b parsing \'{x_unquoted}\'')
|
|
||||||
x_parsed = prompt.parse_string(x_unquoted)
|
|
||||||
#print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed)
|
|
||||||
return x_parsed[0]
|
|
||||||
|
|
||||||
quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
|
|
||||||
quoted_prompt.set_name('quoted_prompt')
|
|
||||||
|
|
||||||
debug_blend=False
|
|
||||||
blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend)
|
|
||||||
blend_weights = (pp.delimited_list(number) + pp.Optional(pp.Char(",").suppress() + "no_normalize")).set_name('blend_weights').set_debug(debug_blend)
|
|
||||||
blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
|
|
||||||
+ pp.Literal(".blend").suppress()
|
|
||||||
+ lparen + pp.Group(blend_weights) + rparen).set_name('blend')
|
|
||||||
blend.set_debug(debug_blend)
|
|
||||||
|
|
||||||
def make_blend(x):
|
|
||||||
prompts = x[0][0]
|
|
||||||
weights = x[0][1]
|
|
||||||
normalize = True
|
|
||||||
if weights[-1] == 'no_normalize':
|
|
||||||
normalize = False
|
|
||||||
weights = weights[:-1]
|
|
||||||
return Blend(prompts=prompts, weights=weights, normalize_weights=normalize)
|
|
||||||
|
|
||||||
blend.set_parse_action(make_blend)
|
|
||||||
|
|
||||||
conjunction_terms = blend_terms.copy().set_name('conjunction_terms')
|
|
||||||
conjunction_weights = blend_weights.copy().set_name('conjunction_weights')
|
|
||||||
conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen
|
|
||||||
+ pp.Literal(".and").suppress()
|
|
||||||
+ lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction')
|
|
||||||
def make_conjunction(x):
|
|
||||||
parts_raw = x[0][0]
|
|
||||||
weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw)
|
|
||||||
parts = [part for part in parts_raw]
|
|
||||||
return Conjunction(parts, weights)
|
|
||||||
conjunction_with_parens_and_quotes.set_parse_action(make_conjunction)
|
|
||||||
|
|
||||||
implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction')
|
|
||||||
implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
|
implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
|
||||||
|
|
||||||
conjunction = conjunction_with_parens_and_quotes | implicit_conjunction
|
conjunction = (explicit_conjunction | implicit_conjunction)
|
||||||
conjunction.set_debug(False)
|
|
||||||
|
|
||||||
# top-level is a conjunction of one or more blends or prompts
|
|
||||||
return conjunction, prompt
|
return conjunction, prompt
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def split_weighted_subprompts(text, skip_normalize=False)->list:
|
def split_weighted_subprompts(text, skip_normalize=False)->list:
|
||||||
"""
|
"""
|
||||||
Legacy blend parsing.
|
Legacy blend parsing.
|
||||||
|
@ -28,8 +28,8 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt(''))
|
self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt(''))
|
||||||
|
|
||||||
def test_basic(self):
|
def test_basic(self):
|
||||||
self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)"))
|
|
||||||
self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames"))
|
self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)"))
|
||||||
self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames"))
|
self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames"))
|
||||||
self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire"))
|
self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire"))
|
||||||
self.assertEqual(make_weighted_conjunction([("cat hot-dog eating", 1)]), parse_prompt("cat hot-dog eating"))
|
self.assertEqual(make_weighted_conjunction([("cat hot-dog eating", 1)]), parse_prompt("cat hot-dog eating"))
|
||||||
@ -37,14 +37,25 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def test_attention(self):
|
def test_attention(self):
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5"))
|
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames).attend(0.5)"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("flames.attend(0.5)"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("\"flames\".attend(0.5)"))
|
||||||
self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5"))
|
self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames).attend(0.5)"))
|
||||||
|
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+"))
|
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+"))
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+"))
|
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+"))
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+"))
|
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames.attend(+)"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames).attend(+)"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\".attend(+)"))
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-"))
|
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-"))
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-"))
|
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-"))
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("\"flames\"-"))
|
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("\"flames\"-"))
|
||||||
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5"))
|
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire flames.attend(0.5)"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames).attend(0.5)"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire \"flames\".attend(0.5)"))
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++"))
|
self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++"))
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--"))
|
self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--"))
|
||||||
self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++"))
|
self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++"))
|
||||||
@ -102,20 +113,17 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
assert_if_prompt_string_not_untouched('a test prompt')
|
assert_if_prompt_string_not_untouched('a test prompt')
|
||||||
assert_if_prompt_string_not_untouched('a badly formed +test prompt')
|
assert_if_prompt_string_not_untouched('a badly formed +test prompt')
|
||||||
with self.assertRaises(pyparsing.ParseException):
|
assert_if_prompt_string_not_untouched('a badly (formed test prompt')
|
||||||
parse_prompt('a badly (formed test prompt')
|
|
||||||
#with self.assertRaises(pyparsing.ParseException):
|
#with self.assertRaises(pyparsing.ParseException):
|
||||||
with self.assertRaises(pyparsing.ParseException):
|
assert_if_prompt_string_not_untouched('a badly (formed +test prompt')
|
||||||
parse_prompt('a badly (formed +test prompt')
|
|
||||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a badly formed +test prompt',1)])]) , parse_prompt('a badly (formed +test )prompt'))
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a badly formed +test prompt',1)])]) , parse_prompt('a badly (formed +test )prompt'))
|
||||||
with self.assertRaises(pyparsing.ParseException):
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(((a badly formed +test prompt',1)])]) , parse_prompt('(((a badly (formed +test )prompt'))
|
||||||
parse_prompt('(((a badly (formed +test )prompt')
|
|
||||||
with self.assertRaises(pyparsing.ParseException):
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(a ba dly f ormed +test prompt',1)])]) , parse_prompt('(a (ba)dly (f)ormed +test prompt'))
|
||||||
parse_prompt('(a (ba)dly (f)ormed +test prompt')
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(a ba dly f ormed +test +prompt',1)])]) , parse_prompt('(a (ba)dly (f)ormed +test +prompt'))
|
||||||
with self.assertRaises(pyparsing.ParseException):
|
self.assertEqual(Conjunction([Blend([FlattenedPrompt([Fragment('((a badly (formed +test', 1)])], [1.0])]),
|
||||||
parse_prompt('(a (ba)dly (f)ormed +test +prompt')
|
parse_prompt('("((a badly (formed +test ").blend(1.0)'))
|
||||||
with self.assertRaises(pyparsing.ParseException):
|
|
||||||
parse_prompt('("((a badly (formed +test ").blend(1.0)')
|
|
||||||
|
|
||||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
|
||||||
parse_prompt("hamburger ((bun))"))
|
parse_prompt("hamburger ((bun))"))
|
||||||
@ -128,6 +136,26 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
def test_blend(self):
|
def test_blend(self):
|
||||||
|
self.assertEqual(Conjunction(
|
||||||
|
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
|
||||||
|
parse_prompt("(\"mountain\", \"man\").blend()")
|
||||||
|
)
|
||||||
|
self.assertEqual(Conjunction(
|
||||||
|
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
|
||||||
|
parse_prompt("(mountain, man).blend()")
|
||||||
|
)
|
||||||
|
self.assertEqual(Conjunction(
|
||||||
|
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
|
||||||
|
parse_prompt("((mountain), (man)).blend()")
|
||||||
|
)
|
||||||
|
self.assertEqual(Conjunction(
|
||||||
|
[Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('tall man', 1.0)])], [1.0, 1.0])]),
|
||||||
|
parse_prompt("((mountain), (tall man)).blend()")
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(PromptParser.ParsingException):
|
||||||
|
print(parse_prompt("((mountain), \"cat.swap(dog)\").blend()"))
|
||||||
|
|
||||||
self.assertEqual(Conjunction(
|
self.assertEqual(Conjunction(
|
||||||
[Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]),
|
[Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]),
|
||||||
parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)")
|
parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)")
|
||||||
@ -167,9 +195,19 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
Conjunction([Blend([FlattenedPrompt([('mountain , man , hairy', 1)]),
|
Conjunction([Blend([FlattenedPrompt([('mountain , man , hairy', 1)]),
|
||||||
FlattenedPrompt([('face, teeth,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0])]),
|
FlattenedPrompt([('face , teeth ,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0], normalize_weights=True)]),
|
||||||
parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)')
|
parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)')
|
||||||
)
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
Conjunction([Blend([FlattenedPrompt([('mountain , man , hairy', 1)]),
|
||||||
|
FlattenedPrompt([('face , teeth ,', 1), ('eyes', 0.9 * 0.9)])], weights=[1.0, -1.0], normalize_weights=False)]),
|
||||||
|
parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1,no_normalize)')
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.assertRaises(PromptParser.ParsingException):
|
||||||
|
parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3, 0.1)")
|
||||||
|
with self.assertRaises(PromptParser.ParsingException):
|
||||||
|
parse_prompt("(\"fire\", \"fire flames\").blend(0.7)")
|
||||||
|
|
||||||
|
|
||||||
def test_nested(self):
|
def test_nested(self):
|
||||||
@ -182,6 +220,9 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def test_cross_attention_control(self):
|
def test_cross_attention_control(self):
|
||||||
|
|
||||||
|
self.assertEqual(Conjunction([FlattenedPrompt([CrossAttentionControlSubstitute([Fragment('sun')], [Fragment('moon')])])]),
|
||||||
|
parse_prompt("sun.swap(moon)"))
|
||||||
|
|
||||||
self.assertEqual(Conjunction([
|
self.assertEqual(Conjunction([
|
||||||
FlattenedPrompt([Fragment('a', 1),
|
FlattenedPrompt([Fragment('a', 1),
|
||||||
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
|
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
|
||||||
@ -259,6 +300,12 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0'))
|
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0'))
|
||||||
|
|
||||||
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
|
||||||
|
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
|
||||||
|
Fragment('eating a', 1),
|
||||||
|
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))])
|
||||||
|
])]),
|
||||||
|
parse_prompt("a cat.swap(dog) eating a hotdog.swap(hotdog++++)"))
|
||||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
|
||||||
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
|
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
|
||||||
Fragment('eating a', 1),
|
Fragment('eating a', 1),
|
||||||
@ -433,6 +480,15 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
def test_single(self):
|
def test_single(self):
|
||||||
|
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]),
|
||||||
|
FlattenedPrompt([("a person with a hat", 1.0),
|
||||||
|
("riding a", 1.1*1.1),
|
||||||
|
CrossAttentionControlSubstitute(
|
||||||
|
[Fragment("bicycle", pow(1.1,2))],
|
||||||
|
[Fragment("skateboard", pow(1.1,2))])
|
||||||
|
])
|
||||||
|
], weights=[0.5, 0.5]),
|
||||||
|
parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)"))
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user