mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
ok now we're cooking
This commit is contained in:
parent
da75876639
commit
2e0b1c4c8b
@ -1,13 +1,17 @@
|
|||||||
import string
|
import string
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import pyparsing
|
|
||||||
import pyparsing as pp
|
import pyparsing as pp
|
||||||
from pyparsing import original_text_for
|
|
||||||
|
|
||||||
|
|
||||||
class Prompt():
|
class Prompt():
|
||||||
|
"""
|
||||||
|
Mid-level structure for storing the tree-like result of parsing a prompt. A Prompt may not represent the whole of
|
||||||
|
the singular user-defined "prompt string" (although it can) - for example, if the user specifies a Blend, the objects
|
||||||
|
that are to be blended together are stored individuall as Prompt objects.
|
||||||
|
|
||||||
|
Nesting makes this object not suitable for directly tokenizing; instead call flatten() on the containing Conjunction
|
||||||
|
to produce a FlattenedPrompt.
|
||||||
|
"""
|
||||||
def __init__(self, parts: list):
|
def __init__(self, parts: list):
|
||||||
for c in parts:
|
for c in parts:
|
||||||
if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
|
if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
|
||||||
@ -22,13 +26,16 @@ class BaseFragment:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
class FlattenedPrompt():
|
class FlattenedPrompt():
|
||||||
|
"""
|
||||||
|
A Prompt that has been passed through flatten(). Its children can be readily tokenized.
|
||||||
|
"""
|
||||||
def __init__(self, parts: list=[]):
|
def __init__(self, parts: list=[]):
|
||||||
# verify type correctness
|
|
||||||
self.children = []
|
self.children = []
|
||||||
for part in parts:
|
for part in parts:
|
||||||
self.append(part)
|
self.append(part)
|
||||||
|
|
||||||
def append(self, fragment: Union[list, BaseFragment, tuple]):
|
def append(self, fragment: Union[list, BaseFragment, tuple]):
|
||||||
|
# verify type correctness
|
||||||
if type(fragment) is list:
|
if type(fragment) is list:
|
||||||
for x in fragment:
|
for x in fragment:
|
||||||
self.append(x)
|
self.append(x)
|
||||||
@ -49,8 +56,11 @@ class FlattenedPrompt():
|
|||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return type(other) is FlattenedPrompt and other.children == self.children
|
return type(other) is FlattenedPrompt and other.children == self.children
|
||||||
|
|
||||||
# abstract base class for Fragments
|
|
||||||
class Fragment(BaseFragment):
|
class Fragment(BaseFragment):
|
||||||
|
"""
|
||||||
|
A Fragment is a chunk of plain text and an optional weight. The text should be passed as-is to the CLIP tokenizer.
|
||||||
|
"""
|
||||||
def __init__(self, text: str, weight: float=1):
|
def __init__(self, text: str, weight: float=1):
|
||||||
assert(type(text) is str)
|
assert(type(text) is str)
|
||||||
if '\\"' in text or '\\(' in text or '\\)' in text:
|
if '\\"' in text or '\\(' in text or '\\)' in text:
|
||||||
@ -67,6 +77,12 @@ class Fragment(BaseFragment):
|
|||||||
and other.weight == self.weight
|
and other.weight == self.weight
|
||||||
|
|
||||||
class Attention():
|
class Attention():
|
||||||
|
"""
|
||||||
|
Nestable weight control for fragments. Each object in the children array may in turn be an Attention object;
|
||||||
|
weights should be considered to accumulate as the tree is traversed to deeper levels of nesting.
|
||||||
|
|
||||||
|
Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object.
|
||||||
|
"""
|
||||||
def __init__(self, weight: float, children: list):
|
def __init__(self, weight: float, children: list):
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
self.children = children
|
self.children = children
|
||||||
@ -81,7 +97,28 @@ class CrossAttentionControlledFragment(BaseFragment):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
|
class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
|
||||||
def __init__(self, original: Fragment, edited: Fragment):
|
"""
|
||||||
|
A Cross-Attention Controlled ('prompt2prompt') fragment, for use inside a Prompt, Attention, or FlattenedPrompt.
|
||||||
|
Representing an "original" word sequence that supplies feature vectors for an initial diffusion operation, and an
|
||||||
|
"edited" word sequence, to which the attention maps produced by the "original" word sequence are applied. Intuitively,
|
||||||
|
the result should be an "edited" image that looks like the "original" image with concepts swapped.
|
||||||
|
|
||||||
|
eg "a cat sitting on a car" (original) -> "a smiling dog sitting on a car" (edited): the edited image should look
|
||||||
|
almost exactly the same as the original, but with a smiling dog rendered in place of the cat. The
|
||||||
|
CrossAttentionControlSubstitute object representing this swap may be confined to the tokens being swapped:
|
||||||
|
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')])
|
||||||
|
or it may represent a larger portion of the token sequence:
|
||||||
|
CrossAttentionControlSubstitute(original=[Fragment('a cat sitting on a car')],
|
||||||
|
edited=[Fragment('a smiling dog sitting on a car')])
|
||||||
|
|
||||||
|
In either case expect it to be embedded in a Prompt or FlattenedPrompt:
|
||||||
|
FlattenedPrompt([
|
||||||
|
Fragment('a'),
|
||||||
|
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]),
|
||||||
|
Fragment('sitting on a car')
|
||||||
|
])
|
||||||
|
"""
|
||||||
|
def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list]):
|
||||||
self.original = original
|
self.original = original
|
||||||
self.edited = edited
|
self.edited = edited
|
||||||
|
|
||||||
@ -92,6 +129,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
|
|||||||
and other.original == self.original \
|
and other.original == self.original \
|
||||||
and other.edited == self.edited
|
and other.edited == self.edited
|
||||||
|
|
||||||
|
|
||||||
class CrossAttentionControlAppend(CrossAttentionControlledFragment):
|
class CrossAttentionControlAppend(CrossAttentionControlledFragment):
|
||||||
def __init__(self, fragment: Fragment):
|
def __init__(self, fragment: Fragment):
|
||||||
self.fragment = fragment
|
self.fragment = fragment
|
||||||
@ -104,6 +142,10 @@ class CrossAttentionControlAppend(CrossAttentionControlledFragment):
|
|||||||
|
|
||||||
|
|
||||||
class Conjunction():
|
class Conjunction():
|
||||||
|
"""
|
||||||
|
Storage for one or more Prompts or Blends, each of which is to be separately diffused and then the results merged
|
||||||
|
by weighted sum in latent space.
|
||||||
|
"""
|
||||||
def __init__(self, prompts: list, weights: list = None):
|
def __init__(self, prompts: list, weights: list = None):
|
||||||
# force everything to be a Prompt
|
# force everything to be a Prompt
|
||||||
#print("making conjunction with", parts)
|
#print("making conjunction with", parts)
|
||||||
@ -125,6 +167,11 @@ class Conjunction():
|
|||||||
|
|
||||||
|
|
||||||
class Blend():
|
class Blend():
|
||||||
|
"""
|
||||||
|
Stores a Blend of multiple Prompts. To apply, build feature vectors for each of the child Prompts and then perform a
|
||||||
|
weighted blend of the feature vectors to produce a single feature vector that is effectively a lerp between the
|
||||||
|
Prompts.
|
||||||
|
"""
|
||||||
def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
|
def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
|
||||||
#print("making Blend with prompts", prompts, "and weights", weights)
|
#print("making Blend with prompts", prompts, "and weights", weights)
|
||||||
if len(prompts) != len(weights):
|
if len(prompts) != len(weights):
|
||||||
@ -152,16 +199,11 @@ class PromptParser():
|
|||||||
|
|
||||||
def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
|
def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
|
||||||
|
|
||||||
self.attention_plus_base = attention_plus_base
|
self.root = build_parser_syntax(attention_plus_base, attention_minus_base)
|
||||||
self.attention_minus_base = attention_minus_base
|
|
||||||
|
|
||||||
self.root = self.build_parser_logic()
|
|
||||||
|
|
||||||
|
|
||||||
def parse(self, prompt: str) -> Conjunction:
|
def parse(self, prompt: str) -> Conjunction:
|
||||||
'''
|
'''
|
||||||
This parser is *very* forgiving. If it cannot parse syntax, it will return strings as-is to be passed on to the
|
|
||||||
diffusion.
|
|
||||||
:param prompt: The prompt string to parse
|
:param prompt: The prompt string to parse
|
||||||
:return: a Conjunction representing the parsed results.
|
:return: a Conjunction representing the parsed results.
|
||||||
'''
|
'''
|
||||||
@ -177,7 +219,16 @@ class PromptParser():
|
|||||||
|
|
||||||
return self.flatten(root[0])
|
return self.flatten(root[0])
|
||||||
|
|
||||||
def flatten(self, root: Conjunction):
|
|
||||||
|
def flatten(self, root: Conjunction) -> Conjunction:
|
||||||
|
"""
|
||||||
|
Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
|
||||||
|
producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
|
||||||
|
that can be readily tokenized without the need to walk a complex tree structure.
|
||||||
|
|
||||||
|
:param root: The Conjunction to flatten.
|
||||||
|
:return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
|
||||||
|
"""
|
||||||
|
|
||||||
#print("flattening", root)
|
#print("flattening", root)
|
||||||
|
|
||||||
@ -242,25 +293,31 @@ class PromptParser():
|
|||||||
flattened_parts = []
|
flattened_parts = []
|
||||||
for part in root.prompts:
|
for part in root.prompts:
|
||||||
flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
|
flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
|
||||||
|
|
||||||
|
#print("flattened to", flattened_parts)
|
||||||
|
|
||||||
weights = root.weights
|
weights = root.weights
|
||||||
return Conjunction(flattened_parts, weights)
|
return Conjunction(flattened_parts, weights)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def build_parser_logic(self):
|
def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
|
||||||
|
|
||||||
lparen = pp.Literal("(").suppress()
|
lparen = pp.Literal("(").suppress()
|
||||||
rparen = pp.Literal(")").suppress()
|
rparen = pp.Literal(")").suppress()
|
||||||
quotes = pp.Literal('"').suppress()
|
quotes = pp.Literal('"').suppress()
|
||||||
|
|
||||||
# accepts int or float notation, always maps to float
|
# accepts int or float notation, always maps to float
|
||||||
number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))
|
number = pp.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))
|
||||||
SPACE_CHARS = string.whitespace
|
|
||||||
greedy_word = pp.Word(pp.printables, exclude_chars=string.whitespace).set_name('greedy_word')
|
greedy_word = pp.Word(pp.printables, exclude_chars=string.whitespace).set_name('greedy_word')
|
||||||
|
|
||||||
attention = pp.Forward()
|
attention = pp.Forward()
|
||||||
|
quoted_fragment = pp.Forward()
|
||||||
|
parenthesized_fragment = pp.Forward()
|
||||||
|
cross_attention_substitute = pp.Forward()
|
||||||
|
prompt_part = pp.Forward()
|
||||||
|
|
||||||
def make_fragment(x):
|
def make_text_fragment(x):
|
||||||
#print("### making fragment for", x)
|
#print("### making fragment for", x)
|
||||||
if type(x) is str:
|
if type(x) is str:
|
||||||
return Fragment(x)
|
return Fragment(x)
|
||||||
@ -270,47 +327,51 @@ class PromptParser():
|
|||||||
else:
|
else:
|
||||||
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
|
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
|
||||||
|
|
||||||
quoted_fragment = pp.Forward()
|
def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
|
||||||
parenthesized_fragment = pp.Forward()
|
fragment_string = x[0]
|
||||||
|
print(f"parsing fragment string \"{fragment_string}\"")
|
||||||
def parse_fragment_str(x):
|
if len(fragment_string.strip()) == 0:
|
||||||
#print("parsing fragment string", x)
|
|
||||||
if len(x[0].strip()) == 0:
|
|
||||||
return Fragment('')
|
return Fragment('')
|
||||||
fragment_parser = pp.Group(pp.OneOrMore(attention | (greedy_word.set_parse_action(make_fragment))))
|
|
||||||
fragment_parser.set_name('word_or_attention')
|
if in_quotes:
|
||||||
result = fragment_parser.parse_string(x[0])
|
# escape unescaped quotes
|
||||||
|
fragment_string = fragment_string.replace('"', '\\"')
|
||||||
|
|
||||||
|
#fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment))))
|
||||||
|
result = pp.Group(pp.MatchFirst([
|
||||||
|
pp.OneOrMore(prompt_part | quoted_fragment),
|
||||||
|
pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd()
|
||||||
|
])).set_name('rr').set_debug(False).parse_string(fragment_string)
|
||||||
#result = (pp.OneOrMore(attention | unquoted_word) + pp.StringEnd()).parse_string(x[0])
|
#result = (pp.OneOrMore(attention | unquoted_word) + pp.StringEnd()).parse_string(x[0])
|
||||||
#print("parsed to", result)
|
#print("parsed to", result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
quoted_fragment << pp.QuotedString(quote_char='"', esc_char='\\')
|
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
|
||||||
quoted_fragment.set_parse_action(parse_fragment_str).set_name('quoted_fragment')
|
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment')
|
||||||
|
|
||||||
self_unescaping_escaped_quote = pp.Literal('\\"').set_parse_action(lambda x: '"')
|
escaped_quote = pp.Literal('\\"')#.set_parse_action(lambda x: '"')
|
||||||
self_unescaping_escaped_lparen = pp.Literal('\\(').set_parse_action(lambda x: '(')
|
escaped_lparen = pp.Literal('\\(')#.set_parse_action(lambda x: '(')
|
||||||
self_unescaping_escaped_rparen = pp.Literal('\\)').set_parse_action(lambda x: ')')
|
escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')')
|
||||||
|
escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"')
|
||||||
|
|
||||||
def not_ends_with_swap(x):
|
def not_ends_with_swap(x):
|
||||||
#print("trying to match:", x)
|
#print("trying to match:", x)
|
||||||
return not x[0].endswith('.swap')
|
return not x[0].endswith('.swap')
|
||||||
|
|
||||||
unquoted_fragment = pp.Combine(pp.OneOrMore(
|
unquoted_fragment = pp.Combine(pp.OneOrMore(
|
||||||
self_unescaping_escaped_rparen | self_unescaping_escaped_lparen | self_unescaping_escaped_quote |
|
escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash |
|
||||||
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()')))
|
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()')))
|
||||||
unquoted_fragment.set_parse_action(make_fragment).set_name('unquoted_fragment').set_debug(True)
|
unquoted_fragment.set_parse_action(make_text_fragment).set_name('unquoted_fragment').set_debug(False)
|
||||||
#print(unquoted_fragment.parse_string("cat.swap(dog)"))
|
#print(unquoted_fragment.parse_string("cat.swap(dog)"))
|
||||||
|
|
||||||
parenthesized_fragment << pp.MatchFirst([
|
parenthesized_fragment << pp.Or([
|
||||||
(lparen + quoted_fragment.copy().set_parse_action(parse_fragment_str).set_debug(False) + rparen).set_name('-quoted_paren_internal').set_debug(False),
|
(lparen + quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False) + rparen).set_name('-quoted_paren_internal').set_debug(False),
|
||||||
(lparen + rparen).set_parse_action(lambda x: make_fragment('')).set_name('-()').set_debug(False),
|
(lparen + rparen).set_parse_action(lambda x: make_text_fragment('')).set_name('-()').set_debug(False),
|
||||||
(lparen + pp.Combine(pp.OneOrMore(
|
(lparen + pp.Combine(pp.OneOrMore(
|
||||||
pp.Literal('\\"').set_debug(False).set_parse_action(lambda x: '"') |
|
escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash |
|
||||||
pp.Literal('\\(').set_debug(False).set_parse_action(lambda x: '(') |
|
|
||||||
pp.Literal('\\)').set_debug(False).set_parse_action(lambda x: ')') |
|
|
||||||
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') |
|
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') |
|
||||||
pp.Word(string.whitespace)
|
pp.Word(string.whitespace)
|
||||||
)).set_name('--combined').set_parse_action(parse_fragment_str).set_debug(False) + rparen)]).set_name('-unquoted_paren_internal').set_debug(False)
|
)).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False) + rparen)]).set_name('-unquoted_paren_internal').set_debug(False)
|
||||||
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False)
|
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False)
|
||||||
|
|
||||||
debug_attention = False
|
debug_attention = False
|
||||||
@ -324,11 +385,14 @@ class PromptParser():
|
|||||||
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\()"')
|
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\()"')
|
||||||
)).set_name('word_inside_attention')
|
)).set_name('word_inside_attention')
|
||||||
attention_with_parens = pp.Forward()
|
attention_with_parens = pp.Forward()
|
||||||
attention_with_parens_delimited_list = pp.delimited_list(pp.Or([
|
|
||||||
|
attention_with_parens_delimited_list = pp.OneOrMore(pp.Or([
|
||||||
quoted_fragment.copy().set_debug(debug_attention),
|
quoted_fragment.copy().set_debug(debug_attention),
|
||||||
attention.copy().set_debug(debug_attention),
|
attention.copy().set_debug(debug_attention),
|
||||||
word_inside_attention.set_debug(debug_attention)]).set_name('delim_inner').set_debug(debug_attention),
|
cross_attention_substitute,
|
||||||
delim=string.whitespace)
|
word_inside_attention.set_debug(debug_attention)
|
||||||
|
#pp.White()
|
||||||
|
]).set_name('delim_inner').set_debug(debug_attention))
|
||||||
# have to disable ignore_expr here to prevent pyparsing from stripping off quote marks
|
# have to disable ignore_expr here to prevent pyparsing from stripping off quote marks
|
||||||
attention_with_parens_body = pp.nested_expr(content=attention_with_parens_delimited_list,
|
attention_with_parens_body = pp.nested_expr(content=attention_with_parens_delimited_list,
|
||||||
ignore_expr=None#((pp.Literal("\\(") | pp.Literal('\\)')))
|
ignore_expr=None#((pp.Literal("\\(") | pp.Literal('\\)')))
|
||||||
@ -341,6 +405,7 @@ class PromptParser():
|
|||||||
attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention)
|
attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention)
|
||||||
|
|
||||||
attention << (attention_with_parens | attention_without_parens)
|
attention << (attention_with_parens | attention_without_parens)
|
||||||
|
attention.set_name('attention')
|
||||||
|
|
||||||
def make_attention(x):
|
def make_attention(x):
|
||||||
#print("making Attention from", x)
|
#print("making Attention from", x)
|
||||||
@ -350,7 +415,7 @@ class PromptParser():
|
|||||||
weight = float(x[0])
|
weight = float(x[0])
|
||||||
# +(str) or -(str) or +str or -str
|
# +(str) or -(str) or +str or -str
|
||||||
elif type(x[0]) is str:
|
elif type(x[0]) is str:
|
||||||
base = self.attention_plus_base if x[0][0] == '+' else self.attention_minus_base
|
base = attention_plus_base if x[0][0] == '+' else attention_minus_base
|
||||||
weight = pow(base, len(x[0]))
|
weight = pow(base, len(x[0]))
|
||||||
if type(x[1]) is list or type(x[1]) is pp.ParseResults:
|
if type(x[1]) is list or type(x[1]) is pp.ParseResults:
|
||||||
return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in x[1]])
|
return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in x[1]])
|
||||||
@ -375,10 +440,10 @@ class PromptParser():
|
|||||||
original_fragment = pp.Or([empty_string.set_debug(debug_cross_attention_control),
|
original_fragment = pp.Or([empty_string.set_debug(debug_cross_attention_control),
|
||||||
quoted_fragment.set_debug(debug_cross_attention_control),
|
quoted_fragment.set_debug(debug_cross_attention_control),
|
||||||
parenthesized_fragment.set_debug(debug_cross_attention_control),
|
parenthesized_fragment.set_debug(debug_cross_attention_control),
|
||||||
pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_fragment) + pp.FollowedBy(".swap")
|
pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_text_fragment) + pp.FollowedBy(".swap")
|
||||||
])
|
])
|
||||||
edited_fragment = parenthesized_fragment
|
edited_fragment = parenthesized_fragment
|
||||||
cross_attention_substitute = original_fragment + pp.Literal(".swap").suppress() + edited_fragment
|
cross_attention_substitute << original_fragment + pp.Literal(".swap").suppress() + edited_fragment
|
||||||
|
|
||||||
original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
|
original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
|
||||||
edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control)
|
edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control)
|
||||||
@ -392,13 +457,11 @@ class PromptParser():
|
|||||||
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
|
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# simple fragments of text
|
# simple fragments of text
|
||||||
# use Or to match the longest
|
# use Or to match the longest
|
||||||
prompt_part = pp.Or([
|
prompt_part << pp.MatchFirst([
|
||||||
cross_attention_substitute,
|
cross_attention_substitute,
|
||||||
attention,
|
attention,
|
||||||
quoted_fragment,
|
|
||||||
unquoted_fragment,
|
unquoted_fragment,
|
||||||
lparen + unquoted_fragment + rparen # matches case where user has +(term) and just deletes the +
|
lparen + unquoted_fragment + rparen # matches case where user has +(term) and just deletes the +
|
||||||
])
|
])
|
||||||
@ -410,7 +473,7 @@ class PromptParser():
|
|||||||
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')
|
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')
|
||||||
|
|
||||||
# root prompt definition
|
# root prompt definition
|
||||||
prompt = ((pp.OneOrMore(prompt_part) | empty) + pp.StringEnd()) \
|
prompt = ((pp.OneOrMore(prompt_part | quoted_fragment) | empty) + pp.StringEnd()) \
|
||||||
.set_parse_action(lambda x: Prompt(x))
|
.set_parse_action(lambda x: Prompt(x))
|
||||||
|
|
||||||
|
|
||||||
|
@ -77,6 +77,15 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
def test_complex_conjunction(self):
|
def test_complex_conjunction(self):
|
||||||
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]),
|
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]),
|
||||||
parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle)\").and(0.5, 0.5)"))
|
parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle)\").and(0.5, 0.5)"))
|
||||||
|
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]),
|
||||||
|
FlattenedPrompt([("a person with a hat", 1.0),
|
||||||
|
("riding a", 1.1*1.1),
|
||||||
|
CrossAttentionControlSubstitute(
|
||||||
|
[Fragment("bicycle", pow(1.1,2))],
|
||||||
|
[Fragment("skateboard", pow(1.1,2))])
|
||||||
|
])
|
||||||
|
], weights=[0.5, 0.5]),
|
||||||
|
parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle.swap(skateboard))\").and(0.5, 0.5)"))
|
||||||
|
|
||||||
def test_badly_formed(self):
|
def test_badly_formed(self):
|
||||||
def make_untouched_prompt(prompt):
|
def make_untouched_prompt(prompt):
|
||||||
@ -309,8 +318,7 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))
|
parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))
|
||||||
|
|
||||||
def test_single(self):
|
def test_single(self):
|
||||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('"man', 1)], [Fragment('monkey', 1)])])]),
|
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mou\)ntain, \\\"man\" \(wit\(h a 2.0(beard))'))
|
||||||
parse_prompt('mountain (\\"man).swap("monkey")'))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
Loading…
x
Reference in New Issue
Block a user