mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip re-writing parts of prompt parser
This commit is contained in:
parent
c9d27634b4
commit
da223dfe81
@ -1,3 +1,5 @@
|
|||||||
|
import string
|
||||||
|
|
||||||
import pyparsing
|
import pyparsing
|
||||||
import pyparsing as pp
|
import pyparsing as pp
|
||||||
from pyparsing import original_text_for
|
from pyparsing import original_text_for
|
||||||
@ -200,8 +202,8 @@ class PromptParser():
|
|||||||
elif type(node) is Fragment:
|
elif type(node) is Fragment:
|
||||||
results += [Fragment(node.text, node.weight*weight_scale)]
|
results += [Fragment(node.text, node.weight*weight_scale)]
|
||||||
elif type(node) is CrossAttentionControlSubstitute:
|
elif type(node) is CrossAttentionControlSubstitute:
|
||||||
original = flatten_internal(node.original, weight_scale, [], ' CAo ')
|
original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ')
|
||||||
edited = flatten_internal(node.edited, weight_scale, [], ' CAe ')
|
edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ')
|
||||||
results += [CrossAttentionControlSubstitute(original, edited)]
|
results += [CrossAttentionControlSubstitute(original, edited)]
|
||||||
elif type(node) is Blend:
|
elif type(node) is Blend:
|
||||||
flattened_subprompts = []
|
flattened_subprompts = []
|
||||||
@ -236,24 +238,46 @@ class PromptParser():
|
|||||||
|
|
||||||
lparen = pp.Literal("(").suppress()
|
lparen = pp.Literal("(").suppress()
|
||||||
rparen = pp.Literal(")").suppress()
|
rparen = pp.Literal(")").suppress()
|
||||||
|
quotes = pp.Literal('"').suppress()
|
||||||
|
|
||||||
# accepts int or float notation, always maps to float
|
# accepts int or float notation, always maps to float
|
||||||
number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))
|
number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))
|
||||||
SPACE_CHARS = ' \t\n'
|
SPACE_CHARS = string.whitespace
|
||||||
|
|
||||||
prompt_part = pp.Forward()
|
|
||||||
word = pp.Word(pp.printables).set_parse_action(lambda x: Fragment(' '.join([s for s in x])))
|
|
||||||
word.set_name("word")
|
|
||||||
word.set_debug(False)
|
|
||||||
|
|
||||||
def make_fragment(x):
|
def make_fragment(x):
|
||||||
#print("### making fragment for", x)
|
#print("### making fragment for", x)
|
||||||
if type(x) is str:
|
if type(x) is str:
|
||||||
return Fragment(x)
|
return Fragment(x)
|
||||||
elif type(x) is pp.ParseResults or type(x) is list:
|
elif type(x) is pp.ParseResults or type(x) is list:
|
||||||
|
#print(f'converting {x} to Fragment')
|
||||||
return Fragment(' '.join([s for s in x]))
|
return Fragment(' '.join([s for s in x]))
|
||||||
else:
|
else:
|
||||||
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
|
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
|
||||||
|
|
||||||
|
def parse_fragment_str(x):
|
||||||
|
return make_fragment(x)
|
||||||
|
|
||||||
|
quoted_fragment = pp.QuotedString(quote_char='"', esc_char='\\')
|
||||||
|
quoted_fragment.set_parse_action(parse_fragment_str).set_name('quoted_fragment')
|
||||||
|
|
||||||
|
unquoted_fragment = pp.Combine(pp.OneOrMore(
|
||||||
|
pp.Literal('\\"').set_debug(False) |
|
||||||
|
pp.Literal('\\').set_debug(False) |
|
||||||
|
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"')
|
||||||
|
))
|
||||||
|
unquoted_fragment.set_parse_action(parse_fragment_str).set_name('unquoted_fragment')
|
||||||
|
|
||||||
|
parenthesized_fragment = \
|
||||||
|
(lparen + quoted_fragment.set_debug(True) + rparen).set_name('quoted_paren_internal') | \
|
||||||
|
(lparen + rparen).set_parse_action(lambda x: make_fragment('')) | \
|
||||||
|
(lparen + pp.Combine(pp.OneOrMore(
|
||||||
|
pp.Literal('\\)').set_debug(False) |
|
||||||
|
pp.Literal('\\').set_debug(False) |
|
||||||
|
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\)') |
|
||||||
|
pp.Word(string.whitespace)
|
||||||
|
)) + rparen).set_parse_action(parse_fragment_str).set_name('unquoted_paren_internal').set_debug(True)
|
||||||
|
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(True)
|
||||||
|
|
||||||
# attention control of the form +(phrase) / -(phrase) / <weight>(phrase)
|
# attention control of the form +(phrase) / -(phrase) / <weight>(phrase)
|
||||||
# phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
|
# phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
|
||||||
attention = pp.Forward()
|
attention = pp.Forward()
|
||||||
@ -303,41 +327,35 @@ class PromptParser():
|
|||||||
pp.Literal('""').suppress() |
|
pp.Literal('""').suppress() |
|
||||||
(lparen + pp.Literal('""').suppress() + rparen)
|
(lparen + pp.Literal('""').suppress() + rparen)
|
||||||
).set_parse_action(lambda x: Fragment(""))
|
).set_parse_action(lambda x: Fragment(""))
|
||||||
|
empty_string.set_name('empty_string')
|
||||||
|
|
||||||
original_words = (
|
original_fragment = empty_string | quoted_fragment | parenthesized_fragment | unquoted_fragment
|
||||||
(lparen + pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) |
|
edited_fragment = parenthesized_fragment
|
||||||
(pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('term2').set_debug(False) |
|
cross_attention_substitute = original_fragment + pp.Literal(".swap").suppress() + edited_fragment
|
||||||
(lparen + (pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)) + rparen).set_name('term3').set_debug(False)
|
|
||||||
).set_name('original_words')
|
cross_attention_substitute.set_name('cross_attention_substitute').set_debug(True)
|
||||||
edited_words = (
|
|
||||||
(pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('termA').set_debug(False) |
|
|
||||||
pp.Literal('""').suppress().set_parse_action(lambda x: Fragment("")) |
|
|
||||||
(pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)).set_name('termB').set_debug(True)
|
|
||||||
).set_name('edited_words')
|
|
||||||
cross_attention_substitute = (empty_string | original_words) + \
|
|
||||||
pp.Literal(".swap").suppress() + \
|
|
||||||
(empty_string | (lparen + edited_words + rparen)
|
|
||||||
)
|
|
||||||
cross_attention_substitute.set_name('cross_attention_substitute')
|
|
||||||
|
|
||||||
def make_cross_attention_substitute(x):
|
def make_cross_attention_substitute(x):
|
||||||
#print("making cacs for", x)
|
print("making cacs for", x)
|
||||||
cacs = CrossAttentionControlSubstitute(x[0], x[1])
|
cacs = CrossAttentionControlSubstitute(x[0], x[1])
|
||||||
#print("made", cacs)
|
print("made", cacs)
|
||||||
return cacs
|
return cacs
|
||||||
|
|
||||||
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
|
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
|
||||||
|
|
||||||
# simple fragments of text
|
# simple fragments of text
|
||||||
prompt_part << (cross_attention_substitute
|
prompt_part = (
|
||||||
|
cross_attention_substitute
|
||||||
| attention
|
| attention
|
||||||
| word
|
| quoted_fragment
|
||||||
|
| unquoted_fragment
|
||||||
)
|
)
|
||||||
prompt_part.set_debug(False)
|
prompt_part.set_debug(False)
|
||||||
prompt_part.set_name("prompt_part")
|
prompt_part.set_name("prompt_part")
|
||||||
|
|
||||||
|
empty = ((lparen + rparen) | (quotes + quotes)).suppress()
|
||||||
|
|
||||||
# root prompt definition
|
# root prompt definition
|
||||||
prompt = pp.Group(pp.OneOrMore(prompt_part))\
|
prompt = (pp.Group(pp.OneOrMore(prompt_part) | empty) + pp.StringEnd()) \
|
||||||
.set_parse_action(lambda x: Prompt(x[0]))
|
.set_parse_action(lambda x: Prompt(x[0]))
|
||||||
|
|
||||||
# weighted blend of prompts
|
# weighted blend of prompts
|
||||||
|
@ -201,6 +201,9 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)'))
|
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)'))
|
||||||
|
|
||||||
|
def test_single(self):
|
||||||
|
print(parse_prompt('fire (trees and houses).swap("flames")'))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user