ok now we're cooking

Damian at mba 2022-10-21 03:29:50 +02:00
parent da75876639
commit 2e0b1c4c8b
2 changed files with 273 additions and 202 deletions


@@ -1,13 +1,17 @@
import string
from typing import Union
import pyparsing
import pyparsing as pp
from pyparsing import original_text_for
class Prompt():
"""
Mid-level structure for storing the tree-like result of parsing a prompt. A Prompt may not represent the whole of
the singular user-defined "prompt string" (although it can) - for example, if the user specifies a Blend, the objects
that are to be blended together are stored individually as Prompt objects.
Nesting makes this object not suitable for directly tokenizing; instead call flatten() on the containing Conjunction
to produce a FlattenedPrompt.
"""
def __init__(self, parts: list):
for c in parts:
if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
@@ -22,13 +26,16 @@ class BaseFragment:
pass
class FlattenedPrompt():
"""
A Prompt that has been passed through flatten(). Its children can be readily tokenized.
"""
def __init__(self, parts: list=[]):
# verify type correctness
self.children = []
for part in parts:
self.append(part)
def append(self, fragment: Union[list, BaseFragment, tuple]):
# verify type correctness
if type(fragment) is list:
for x in fragment:
self.append(x)
@@ -49,8 +56,11 @@ class FlattenedPrompt():
def __eq__(self, other):
return type(other) is FlattenedPrompt and other.children == self.children
# abstract base class for Fragments
class Fragment(BaseFragment):
"""
A Fragment is a chunk of plain text and an optional weight. The text should be passed as-is to the CLIP tokenizer.
"""
def __init__(self, text: str, weight: float=1):
assert(type(text) is str)
if '\\"' in text or '\\(' in text or '\\)' in text:
@@ -67,6 +77,12 @@ class Fragment(BaseFragment):
and other.weight == self.weight
class Attention():
"""
Nestable weight control for fragments. Each object in the children array may in turn be an Attention object;
weights should be considered to accumulate as the tree is traversed to deeper levels of nesting.
Do not traverse directly; instead obtain a FlattenedPrompt by calling flatten() on a top-level Conjunction object.
"""
def __init__(self, weight: float, children: list):
self.weight = weight
self.children = children
@@ -81,7 +97,28 @@ class CrossAttentionControlledFragment(BaseFragment):
pass
class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
- def __init__(self, original: Fragment, edited: Fragment):
"""
A Cross-Attention Controlled ('prompt2prompt') fragment, for use inside a Prompt, Attention, or FlattenedPrompt.
It represents an "original" word sequence that supplies feature vectors for an initial diffusion operation, and an
"edited" word sequence, to which the attention maps produced by the "original" word sequence are applied. Intuitively,
the result should be an "edited" image that looks like the "original" image with concepts swapped.
eg "a cat sitting on a car" (original) -> "a smiling dog sitting on a car" (edited): the edited image should look
almost exactly the same as the original, but with a smiling dog rendered in place of the cat. The
CrossAttentionControlSubstitute object representing this swap may be confined to the tokens being swapped:
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')])
or it may represent a larger portion of the token sequence:
CrossAttentionControlSubstitute(original=[Fragment('a cat sitting on a car')],
edited=[Fragment('a smiling dog sitting on a car')])
In either case expect it to be embedded in a Prompt or FlattenedPrompt:
FlattenedPrompt([
Fragment('a'),
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]),
Fragment('sitting on a car')
])
"""
def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list]):
self.original = original
self.edited = edited
@@ -92,6 +129,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
and other.original == self.original \
and other.edited == self.edited
class CrossAttentionControlAppend(CrossAttentionControlledFragment):
def __init__(self, fragment: Fragment):
self.fragment = fragment
@@ -104,6 +142,10 @@ class CrossAttentionControlAppend(CrossAttentionControlledFragment):
class Conjunction():
"""
Storage for one or more Prompts or Blends, each of which is to be separately diffused and then the results merged
by weighted sum in latent space.
"""
def __init__(self, prompts: list, weights: list = None):
# force everything to be a Prompt
#print("making conjunction with", parts)
@@ -125,6 +167,11 @@ class Conjunction():
class Blend():
"""
Stores a Blend of multiple Prompts. To apply, build feature vectors for each of the child Prompts and then perform a
weighted blend of the feature vectors to produce a single feature vector that is effectively a lerp between the
Prompts.
"""
def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
#print("making Blend with prompts", prompts, "and weights", weights)
if len(prompts) != len(weights):
@@ -152,16 +199,11 @@ class PromptParser():
def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
- self.attention_plus_base = attention_plus_base
- self.attention_minus_base = attention_minus_base
- self.root = self.build_parser_logic()
+ self.root = build_parser_syntax(attention_plus_base, attention_minus_base)
def parse(self, prompt: str) -> Conjunction:
'''
- This parser is *very* forgiving. If it cannot parse syntax, it will return strings as-is to be passed on to the
- diffusion.
:param prompt: The prompt string to parse
:return: a Conjunction representing the parsed results.
'''
@@ -177,7 +219,16 @@ class PromptParser():
return self.flatten(root[0])
def flatten(self, root: Conjunction):
def flatten(self, root: Conjunction) -> Conjunction:
"""
Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
that can be readily tokenized without the need to walk a complex tree structure.
:param root: The Conjunction to flatten.
:return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
"""
#print("flattening", root) #print("flattening", root)
@@ -242,226 +293,238 @@ class PromptParser():
flattened_parts = []
for part in root.prompts:
flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
#print("flattened to", flattened_parts)
weights = root.weights
return Conjunction(flattened_parts, weights)
- def build_parser_logic(self):
- lparen = pp.Literal("(").suppress()
- rparen = pp.Literal(")").suppress()
- quotes = pp.Literal('"').suppress()
- # accepts int or float notation, always maps to float
- number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))
- SPACE_CHARS = string.whitespace
- greedy_word = pp.Word(pp.printables, exclude_chars=string.whitespace).set_name('greedy_word')
- attention = pp.Forward()
- def make_fragment(x):
- #print("### making fragment for", x)
- if type(x) is str:
- return Fragment(x)
- elif type(x) is pp.ParseResults or type(x) is list:
- #print(f'converting {type(x).__name__} to Fragment')
- return Fragment(' '.join([s for s in x]))
- else:
- raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
- quoted_fragment = pp.Forward()
- parenthesized_fragment = pp.Forward()
- def parse_fragment_str(x):
- #print("parsing fragment string", x)
- if len(x[0].strip()) == 0:
- return Fragment('')
- fragment_parser = pp.Group(pp.OneOrMore(attention | (greedy_word.set_parse_action(make_fragment))))
- fragment_parser.set_name('word_or_attention')
- result = fragment_parser.parse_string(x[0])
- #result = (pp.OneOrMore(attention | unquoted_word) + pp.StringEnd()).parse_string(x[0])
- #print("parsed to", result)
- return result
- quoted_fragment << pp.QuotedString(quote_char='"', esc_char='\\')
- quoted_fragment.set_parse_action(parse_fragment_str).set_name('quoted_fragment')
- self_unescaping_escaped_quote = pp.Literal('\\"').set_parse_action(lambda x: '"')
- self_unescaping_escaped_lparen = pp.Literal('\\(').set_parse_action(lambda x: '(')
- self_unescaping_escaped_rparen = pp.Literal('\\)').set_parse_action(lambda x: ')')
- def not_ends_with_swap(x):
- #print("trying to match:", x)
- return not x[0].endswith('.swap')
- unquoted_fragment = pp.Combine(pp.OneOrMore(
- self_unescaping_escaped_rparen | self_unescaping_escaped_lparen | self_unescaping_escaped_quote |
- pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()')))
- unquoted_fragment.set_parse_action(make_fragment).set_name('unquoted_fragment').set_debug(True)
- #print(unquoted_fragment.parse_string("cat.swap(dog)"))
- parenthesized_fragment << pp.MatchFirst([
- (lparen + quoted_fragment.copy().set_parse_action(parse_fragment_str).set_debug(False) + rparen).set_name('-quoted_paren_internal').set_debug(False),
- (lparen + rparen).set_parse_action(lambda x: make_fragment('')).set_name('-()').set_debug(False),
- (lparen + pp.Combine(pp.OneOrMore(
- pp.Literal('\\"').set_debug(False).set_parse_action(lambda x: '"') |
- pp.Literal('\\(').set_debug(False).set_parse_action(lambda x: '(') |
- pp.Literal('\\)').set_debug(False).set_parse_action(lambda x: ')') |
- pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') |
- pp.Word(string.whitespace)
- )).set_name('--combined').set_parse_action(parse_fragment_str).set_debug(False) + rparen)]).set_name('-unquoted_paren_internal').set_debug(False)
- parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False)
- debug_attention = False
- # attention control of the form +(phrase) / -(phrase) / <weight>(phrase)
- # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
- attention_head = (number | pp.Word('+') | pp.Word('-'))\
- .set_name("attention_head")\
- .set_debug(False)
- word_inside_attention = pp.Combine(pp.OneOrMore(
- pp.Literal('\\)') | pp.Literal('\\(') | pp.Literal('\\"') |
- pp.Word(pp.printables, exclude_chars=string.whitespace + '\\()"')
- )).set_name('word_inside_attention')
- attention_with_parens = pp.Forward()
- attention_with_parens_delimited_list = pp.delimited_list(pp.Or([
- quoted_fragment.copy().set_debug(debug_attention),
- attention.copy().set_debug(debug_attention),
- word_inside_attention.set_debug(debug_attention)]).set_name('delim_inner').set_debug(debug_attention),
- delim=string.whitespace)
- # have to disable ignore_expr here to prevent pyparsing from stripping off quote marks
- attention_with_parens_body = pp.nested_expr(content=attention_with_parens_delimited_list,
- ignore_expr=None#((pp.Literal("\\(") | pp.Literal('\\)')))
- )
- attention_with_parens_body.set_debug(debug_attention)
- attention_with_parens << (attention_head + attention_with_parens_body)
- attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention)
- attention_without_parens = (pp.Word('+') | pp.Word('-')) + (quoted_fragment | word_inside_attention)
- attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention)
- attention << (attention_with_parens | attention_without_parens)
- def make_attention(x):
- #print("making Attention from", x)
- weight = 1
- # number(str)
- if type(x[0]) is float or type(x[0]) is int:
- weight = float(x[0])
- # +(str) or -(str) or +str or -str
- elif type(x[0]) is str:
- base = self.attention_plus_base if x[0][0] == '+' else self.attention_minus_base
- weight = pow(base, len(x[0]))
- if type(x[1]) is list or type(x[1]) is pp.ParseResults:
- return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in x[1]])
- elif type(x[1]) is str:
- return Attention(weight=weight, children=[Fragment(x[1])])
- elif type(x[1]) is Fragment:
- return Attention(weight=weight, children=[x[1]])
- raise PromptParser.ParsingException(f"Don't know how to make attention with children {x[1]}")
- attention_with_parens.set_parse_action(make_attention)
- attention_without_parens.set_parse_action(make_attention)
- # cross-attention control
- empty_string = ((lparen + rparen) |
- pp.Literal('""').suppress() |
- (lparen + pp.Literal('""').suppress() + rparen)
- ).set_parse_action(lambda x: Fragment(""))
- empty_string.set_name('empty_string')
- # cross attention control
- debug_cross_attention_control = False
- original_fragment = pp.Or([empty_string.set_debug(debug_cross_attention_control),
- quoted_fragment.set_debug(debug_cross_attention_control),
- parenthesized_fragment.set_debug(debug_cross_attention_control),
- pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_fragment) + pp.FollowedBy(".swap")
- ])
- edited_fragment = parenthesized_fragment
- cross_attention_substitute = original_fragment + pp.Literal(".swap").suppress() + edited_fragment
- original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
- edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control)
- cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control)
- def make_cross_attention_substitute(x):
- #print("making cacs for", x)
- cacs = CrossAttentionControlSubstitute(x[0], x[1])
- #print("made", cacs)
- return cacs
- cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
- # simple fragments of text
- # use Or to match the longest
- prompt_part = pp.Or([
- cross_attention_substitute,
- attention,
- quoted_fragment,
- unquoted_fragment,
- lparen + unquoted_fragment + rparen # matches case where user has +(term) and just deletes the +
- ])
- prompt_part.set_debug(False)
- prompt_part.set_name("prompt_part")
- empty = (
- (lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) |
- (quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')
- # root prompt definition
- prompt = ((pp.OneOrMore(prompt_part) | empty) + pp.StringEnd()) \
- .set_parse_action(lambda x: Prompt(x))
- # weighted blend of prompts
- # ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or
- # int weights.
- # can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c)
- def make_prompt_from_quoted_string(x):
- #print(' got quoted prompt', x)
- x_unquoted = x[0][1:-1]
- if len(x_unquoted.strip()) == 0:
- # print(' b : just an empty string')
- return Prompt([Fragment('')])
- # print(' b parsing ', c_unquoted)
- x_parsed = prompt.parse_string(x_unquoted)
- #print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed)
- return x_parsed[0]
- quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
- quoted_prompt.set_name('quoted_prompt')
- blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms')
- blend_weights = pp.delimited_list(number).set_name('blend_weights')
- blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
- + pp.Literal(".blend").suppress()
- + lparen + pp.Group(blend_weights) + rparen).set_name('blend')
- blend.set_debug(False)
- blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1]))
- conjunction_terms = blend_terms.copy().set_name('conjunction_terms')
- conjunction_weights = blend_weights.copy().set_name('conjunction_weights')
- conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen
- + pp.Literal(".and").suppress()
- + lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction')
- def make_conjunction(x):
- parts_raw = x[0][0]
- weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw)
- parts = [part for part in parts_raw]
- return Conjunction(parts, weights)
- conjunction_with_parens_and_quotes.set_parse_action(make_conjunction)
- implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction')
- implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
- conjunction = conjunction_with_parens_and_quotes | implicit_conjunction
- conjunction.set_debug(False)
- # top-level is a conjunction of one or more blends or prompts
- return conjunction
+ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
+ lparen = pp.Literal("(").suppress()
+ rparen = pp.Literal(")").suppress()
+ quotes = pp.Literal('"').suppress()
+ # accepts int or float notation, always maps to float
+ number = pp.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))
+ greedy_word = pp.Word(pp.printables, exclude_chars=string.whitespace).set_name('greedy_word')
+ attention = pp.Forward()
+ quoted_fragment = pp.Forward()
+ parenthesized_fragment = pp.Forward()
+ cross_attention_substitute = pp.Forward()
+ prompt_part = pp.Forward()
+ def make_text_fragment(x):
+ #print("### making fragment for", x)
+ if type(x) is str:
+ return Fragment(x)
+ elif type(x) is pp.ParseResults or type(x) is list:
+ #print(f'converting {type(x).__name__} to Fragment')
+ return Fragment(' '.join([s for s in x]))
+ else:
+ raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
+ def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
+ fragment_string = x[0]
+ print(f"parsing fragment string \"{fragment_string}\"")
+ if len(fragment_string.strip()) == 0:
+ return Fragment('')
+ if in_quotes:
+ # escape unescaped quotes
+ fragment_string = fragment_string.replace('"', '\\"')
+ #fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment))))
+ result = pp.Group(pp.MatchFirst([
+ pp.OneOrMore(prompt_part | quoted_fragment),
+ pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd()
+ ])).set_name('rr').set_debug(False).parse_string(fragment_string)
+ #result = (pp.OneOrMore(attention | unquoted_word) + pp.StringEnd()).parse_string(x[0])
+ #print("parsed to", result)
+ return result
+ quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
+ quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment')
+ escaped_quote = pp.Literal('\\"')#.set_parse_action(lambda x: '"')
+ escaped_lparen = pp.Literal('\\(')#.set_parse_action(lambda x: '(')
+ escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')')
+ escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"')
+ def not_ends_with_swap(x):
+ #print("trying to match:", x)
+ return not x[0].endswith('.swap')
+ unquoted_fragment = pp.Combine(pp.OneOrMore(
+ escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash |
+ pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()')))
+ unquoted_fragment.set_parse_action(make_text_fragment).set_name('unquoted_fragment').set_debug(False)
+ #print(unquoted_fragment.parse_string("cat.swap(dog)"))
+ parenthesized_fragment << pp.Or([
+ (lparen + quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False) + rparen).set_name('-quoted_paren_internal').set_debug(False),
+ (lparen + rparen).set_parse_action(lambda x: make_text_fragment('')).set_name('-()').set_debug(False),
+ (lparen + pp.Combine(pp.OneOrMore(
+ escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash |
+ pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') |
+ pp.Word(string.whitespace)
+ )).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False) + rparen)]).set_name('-unquoted_paren_internal').set_debug(False)
+ parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False)
+ debug_attention = False
+ # attention control of the form +(phrase) / -(phrase) / <weight>(phrase)
+ # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
+ attention_head = (number | pp.Word('+') | pp.Word('-'))\
+ .set_name("attention_head")\
+ .set_debug(False)
+ word_inside_attention = pp.Combine(pp.OneOrMore(
+ pp.Literal('\\)') | pp.Literal('\\(') | pp.Literal('\\"') |
+ pp.Word(pp.printables, exclude_chars=string.whitespace + '\\()"')
+ )).set_name('word_inside_attention')
+ attention_with_parens = pp.Forward()
+ attention_with_parens_delimited_list = pp.OneOrMore(pp.Or([
+ quoted_fragment.copy().set_debug(debug_attention),
+ attention.copy().set_debug(debug_attention),
+ cross_attention_substitute,
+ word_inside_attention.set_debug(debug_attention)
+ #pp.White()
+ ]).set_name('delim_inner').set_debug(debug_attention))
+ # have to disable ignore_expr here to prevent pyparsing from stripping off quote marks
+ attention_with_parens_body = pp.nested_expr(content=attention_with_parens_delimited_list,
+ ignore_expr=None#((pp.Literal("\\(") | pp.Literal('\\)')))
+ )
+ attention_with_parens_body.set_debug(debug_attention)
+ attention_with_parens << (attention_head + attention_with_parens_body)
+ attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention)
+ attention_without_parens = (pp.Word('+') | pp.Word('-')) + (quoted_fragment | word_inside_attention)
+ attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention)
+ attention << (attention_with_parens | attention_without_parens)
+ attention.set_name('attention')
+ def make_attention(x):
+ #print("making Attention from", x)
+ weight = 1
+ # number(str)
+ if type(x[0]) is float or type(x[0]) is int:
+ weight = float(x[0])
+ # +(str) or -(str) or +str or -str
+ elif type(x[0]) is str:
+ base = attention_plus_base if x[0][0] == '+' else attention_minus_base
+ weight = pow(base, len(x[0]))
+ if type(x[1]) is list or type(x[1]) is pp.ParseResults:
+ return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in x[1]])
+ elif type(x[1]) is str:
+ return Attention(weight=weight, children=[Fragment(x[1])])
+ elif type(x[1]) is Fragment:
+ return Attention(weight=weight, children=[x[1]])
+ raise PromptParser.ParsingException(f"Don't know how to make attention with children {x[1]}")
+ attention_with_parens.set_parse_action(make_attention)
+ attention_without_parens.set_parse_action(make_attention)
+ # cross-attention control
+ empty_string = ((lparen + rparen) |
+ pp.Literal('""').suppress() |
+ (lparen + pp.Literal('""').suppress() + rparen)
+ ).set_parse_action(lambda x: Fragment(""))
+ empty_string.set_name('empty_string')
+ # cross attention control
+ debug_cross_attention_control = False
+ original_fragment = pp.Or([empty_string.set_debug(debug_cross_attention_control),
+ quoted_fragment.set_debug(debug_cross_attention_control),
+ parenthesized_fragment.set_debug(debug_cross_attention_control),
+ pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_text_fragment) + pp.FollowedBy(".swap")
+ ])
+ edited_fragment = parenthesized_fragment
+ cross_attention_substitute << original_fragment + pp.Literal(".swap").suppress() + edited_fragment
+ original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
+ edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control)
+ cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control)
+ def make_cross_attention_substitute(x):
+ #print("making cacs for", x)
+ cacs = CrossAttentionControlSubstitute(x[0], x[1])
+ #print("made", cacs)
+ return cacs
+ cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
+ # simple fragments of text
+ # use Or to match the longest
+ prompt_part << pp.MatchFirst([
+ cross_attention_substitute,
+ attention,
+ unquoted_fragment,
+ lparen + unquoted_fragment + rparen # matches case where user has +(term) and just deletes the +
+ ])
+ prompt_part.set_debug(False)
+ prompt_part.set_name("prompt_part")
+ empty = (
+ (lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) |
+ (quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')
+ # root prompt definition
+ prompt = ((pp.OneOrMore(prompt_part | quoted_fragment) | empty) + pp.StringEnd()) \
+ .set_parse_action(lambda x: Prompt(x))
+ # weighted blend of prompts
+ # ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or
+ # int weights.
+ # can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c)
+ def make_prompt_from_quoted_string(x):
+ #print(' got quoted prompt', x)
+ x_unquoted = x[0][1:-1]
+ if len(x_unquoted.strip()) == 0:
+ # print(' b : just an empty string')
+ return Prompt([Fragment('')])
+ # print(' b parsing ', c_unquoted)
+ x_parsed = prompt.parse_string(x_unquoted)
+ #print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed)
+ return x_parsed[0]
+ quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
+ quoted_prompt.set_name('quoted_prompt')
+ blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms')
+ blend_weights = pp.delimited_list(number).set_name('blend_weights')
+ blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
+ + pp.Literal(".blend").suppress()
+ + lparen + pp.Group(blend_weights) + rparen).set_name('blend')
+ blend.set_debug(False)
+ blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1]))
+ conjunction_terms = blend_terms.copy().set_name('conjunction_terms')
+ conjunction_weights = blend_weights.copy().set_name('conjunction_weights')
+ conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen
+ + pp.Literal(".and").suppress()
+ + lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction')
+ def make_conjunction(x):
+ parts_raw = x[0][0]
+ weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw)
+ parts = [part for part in parts_raw]
+ return Conjunction(parts, weights)
+ conjunction_with_parens_and_quotes.set_parse_action(make_conjunction)
+ implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction')
+ implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
+ conjunction = conjunction_with_parens_and_quotes | implicit_conjunction
+ conjunction.set_debug(False)
+ # top-level is a conjunction of one or more blends or prompts
+ return conjunction
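Taken together, the rewritten grammar accepts plain text, attention groups built from '+'/'-'/numeric weights, .swap() cross-attention substitutions, quoted-prompt .blend() groups, and explicit .and() conjunctions. A minimal sketch of exercising it (the import path and prompt strings are illustrative assumptions, not part of this commit):

from prompt_parser import PromptParser  # assumed import path

parser = PromptParser(attention_plus_base=1.1, attention_minus_base=0.9)

# attention: each '+' scales the group's weight by 1.1, each '-' by 0.9
parser.parse('a forest +(misty morning)')

# cross-attention control: diffuse with "cat", then reuse its attention maps for "dog"
parser.parse('a cat.swap(dog) sitting on a car')

# blend: weighted lerp of the two prompts' feature vectors
parser.parse('("a tall thin man", "a short fat man").blend(1.0, 0.5)')

# conjunction: diffuse each prompt separately, then merge by weighted sum in latent space
parser.parse('("mountain man", "a person with a hat").and(0.5, 0.5)')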


@@ -77,6 +77,15 @@ class PromptParserTestCase(unittest.TestCase):
def test_complex_conjunction(self):
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]),
parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle)\").and(0.5, 0.5)"))
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]),
FlattenedPrompt([("a person with a hat", 1.0),
("riding a", 1.1*1.1),
CrossAttentionControlSubstitute(
[Fragment("bicycle", pow(1.1,2))],
[Fragment("skateboard", pow(1.1,2))])
])
], weights=[0.5, 0.5]),
parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle.swap(skateboard))\").and(0.5, 0.5)"))
def test_badly_formed(self):
def make_untouched_prompt(prompt):
@@ -309,8 +318,7 @@ class PromptParserTestCase(unittest.TestCase):
parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))
def test_single(self):
- self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('"man', 1)], [Fragment('monkey', 1)])])]),
- parse_prompt('mountain (\\"man).swap("monkey")'))
+ self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy +("mou\)ntain, \\\"man\" \(wit\(h a 2.0(beard))'))
if __name__ == '__main__':
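Downstream code is meant to walk a FlattenedPrompt linearly rather than recurse into a nested tree. A rough sketch of pulling out tokenizer-ready text and weights, assuming Fragment exposes its constructor arguments as .text and .weight (the helper itself is hypothetical, not part of this commit):

def fragments_and_weights(flattened_prompt):
    # children are Fragment objects (plain text plus a weight) or
    # CrossAttentionControlSubstitute objects carrying original/edited fragment lists
    texts, weights = [], []
    for child in flattened_prompt.children:
        if isinstance(child, Fragment):
            texts.append(child.text)
            weights.append(child.weight)
    return texts, weights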