mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip re-writing parts of prompt parser
This commit is contained in:
parent
c9d27634b4
commit
da223dfe81
@ -1,3 +1,5 @@
|
||||
import string
|
||||
|
||||
import pyparsing
|
||||
import pyparsing as pp
|
||||
from pyparsing import original_text_for
|
||||
@ -200,8 +202,8 @@ class PromptParser():
|
||||
elif type(node) is Fragment:
|
||||
results += [Fragment(node.text, node.weight*weight_scale)]
|
||||
elif type(node) is CrossAttentionControlSubstitute:
|
||||
original = flatten_internal(node.original, weight_scale, [], ' CAo ')
|
||||
edited = flatten_internal(node.edited, weight_scale, [], ' CAe ')
|
||||
original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ')
|
||||
edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ')
|
||||
results += [CrossAttentionControlSubstitute(original, edited)]
|
||||
elif type(node) is Blend:
|
||||
flattened_subprompts = []
|
||||
@ -236,24 +238,46 @@ class PromptParser():
|
||||
|
||||
lparen = pp.Literal("(").suppress()
|
||||
rparen = pp.Literal(")").suppress()
|
||||
quotes = pp.Literal('"').suppress()
|
||||
|
||||
# accepts int or float notation, always maps to float
|
||||
number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))
|
||||
SPACE_CHARS = ' \t\n'
|
||||
|
||||
prompt_part = pp.Forward()
|
||||
word = pp.Word(pp.printables).set_parse_action(lambda x: Fragment(' '.join([s for s in x])))
|
||||
word.set_name("word")
|
||||
word.set_debug(False)
|
||||
SPACE_CHARS = string.whitespace
|
||||
|
||||
def make_fragment(x):
|
||||
#print("### making fragment for", x)
|
||||
if type(x) is str:
|
||||
return Fragment(x)
|
||||
elif type(x) is pp.ParseResults or type(x) is list:
|
||||
#print(f'converting {x} to Fragment')
|
||||
return Fragment(' '.join([s for s in x]))
|
||||
else:
|
||||
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
|
||||
|
||||
def parse_fragment_str(x):
|
||||
return make_fragment(x)
|
||||
|
||||
quoted_fragment = pp.QuotedString(quote_char='"', esc_char='\\')
|
||||
quoted_fragment.set_parse_action(parse_fragment_str).set_name('quoted_fragment')
|
||||
|
||||
unquoted_fragment = pp.Combine(pp.OneOrMore(
|
||||
pp.Literal('\\"').set_debug(False) |
|
||||
pp.Literal('\\').set_debug(False) |
|
||||
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"')
|
||||
))
|
||||
unquoted_fragment.set_parse_action(parse_fragment_str).set_name('unquoted_fragment')
|
||||
|
||||
parenthesized_fragment = \
|
||||
(lparen + quoted_fragment.set_debug(True) + rparen).set_name('quoted_paren_internal') | \
|
||||
(lparen + rparen).set_parse_action(lambda x: make_fragment('')) | \
|
||||
(lparen + pp.Combine(pp.OneOrMore(
|
||||
pp.Literal('\\)').set_debug(False) |
|
||||
pp.Literal('\\').set_debug(False) |
|
||||
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\)') |
|
||||
pp.Word(string.whitespace)
|
||||
)) + rparen).set_parse_action(parse_fragment_str).set_name('unquoted_paren_internal').set_debug(True)
|
||||
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(True)
|
||||
|
||||
# attention control of the form +(phrase) / -(phrase) / <weight>(phrase)
|
||||
# phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
|
||||
attention = pp.Forward()
|
||||
@ -303,41 +327,35 @@ class PromptParser():
|
||||
pp.Literal('""').suppress() |
|
||||
(lparen + pp.Literal('""').suppress() + rparen)
|
||||
).set_parse_action(lambda x: Fragment(""))
|
||||
empty_string.set_name('empty_string')
|
||||
|
||||
original_words = (
|
||||
(lparen + pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) |
|
||||
(pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('term2').set_debug(False) |
|
||||
(lparen + (pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)) + rparen).set_name('term3').set_debug(False)
|
||||
).set_name('original_words')
|
||||
edited_words = (
|
||||
(pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('termA').set_debug(False) |
|
||||
pp.Literal('""').suppress().set_parse_action(lambda x: Fragment("")) |
|
||||
(pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)).set_name('termB').set_debug(True)
|
||||
).set_name('edited_words')
|
||||
cross_attention_substitute = (empty_string | original_words) + \
|
||||
pp.Literal(".swap").suppress() + \
|
||||
(empty_string | (lparen + edited_words + rparen)
|
||||
)
|
||||
cross_attention_substitute.set_name('cross_attention_substitute')
|
||||
original_fragment = empty_string | quoted_fragment | parenthesized_fragment | unquoted_fragment
|
||||
edited_fragment = parenthesized_fragment
|
||||
cross_attention_substitute = original_fragment + pp.Literal(".swap").suppress() + edited_fragment
|
||||
|
||||
cross_attention_substitute.set_name('cross_attention_substitute').set_debug(True)
|
||||
|
||||
def make_cross_attention_substitute(x):
|
||||
#print("making cacs for", x)
|
||||
print("making cacs for", x)
|
||||
cacs = CrossAttentionControlSubstitute(x[0], x[1])
|
||||
#print("made", cacs)
|
||||
print("made", cacs)
|
||||
return cacs
|
||||
|
||||
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
|
||||
|
||||
# simple fragments of text
|
||||
prompt_part << (cross_attention_substitute
|
||||
| attention
|
||||
| word
|
||||
)
|
||||
prompt_part = (
|
||||
cross_attention_substitute
|
||||
| attention
|
||||
| quoted_fragment
|
||||
| unquoted_fragment
|
||||
)
|
||||
prompt_part.set_debug(False)
|
||||
prompt_part.set_name("prompt_part")
|
||||
|
||||
empty = ((lparen + rparen) | (quotes + quotes)).suppress()
|
||||
|
||||
# root prompt definition
|
||||
prompt = pp.Group(pp.OneOrMore(prompt_part))\
|
||||
prompt = (pp.Group(pp.OneOrMore(prompt_part) | empty) + pp.StringEnd()) \
|
||||
.set_parse_action(lambda x: Prompt(x[0]))
|
||||
|
||||
# weighted blend of prompts
|
||||
|
@ -201,6 +201,9 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)'))
|
||||
|
||||
def test_single(self):
|
||||
print(parse_prompt('fire (trees and houses).swap("flames")'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user