mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
parser working with basic escapes
This commit is contained in:
parent
da223dfe81
commit
79b4afeae7
@ -169,12 +169,16 @@ class PromptParser():
|
||||
|
||||
def flatten(self, root: Conjunction):
|
||||
|
||||
print("flattening", root)
|
||||
|
||||
def fuse_fragments(items):
|
||||
# print("fusing fragments in ", items)
|
||||
result = []
|
||||
for x in items:
|
||||
if issubclass(type(x), CrossAttentionControlledFragment):
|
||||
result.append(x)
|
||||
if type(x) is CrossAttentionControlSubstitute:
|
||||
original_fused = fuse_fragments(x.original)
|
||||
edited_fused = fuse_fragments(x.edited)
|
||||
result.append(CrossAttentionControlSubstitute(original_fused, edited_fused))
|
||||
else:
|
||||
last_weight = result[-1].weight \
|
||||
if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \
|
||||
@ -221,10 +225,9 @@ class PromptParser():
|
||||
#print(prefix + "after flattening Prompt, results is", results)
|
||||
else:
|
||||
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
|
||||
print(prefix + "-> after flattening", type(node), "results is", results)
|
||||
print(prefix + "-> after flattening", type(node).__name__, "results is", results)
|
||||
return results
|
||||
|
||||
#print("flattening", root)
|
||||
|
||||
flattened_parts = []
|
||||
for part in root.prompts:
|
||||
@ -244,6 +247,8 @@ class PromptParser():
|
||||
number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))
|
||||
SPACE_CHARS = string.whitespace
|
||||
|
||||
attention = pp.Forward()
|
||||
|
||||
def make_fragment(x):
|
||||
#print("### making fragment for", x)
|
||||
if type(x) is str:
|
||||
@ -254,33 +259,44 @@ class PromptParser():
|
||||
else:
|
||||
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
|
||||
|
||||
unquoted_fragment = pp.Forward()
|
||||
quoted_fragment = pp.Forward()
|
||||
parenthesized_fragment = pp.Forward()
|
||||
|
||||
def parse_fragment_str(x):
|
||||
return make_fragment(x)
|
||||
print("parsing", x)
|
||||
if len(x[0].strip()) == 0:
|
||||
return Fragment('')
|
||||
fragment_parser = pp.Group(pp.OneOrMore(attention | pp.Word(pp.printables, exclude_chars=string.whitespace).set_parse_action(make_fragment)))
|
||||
fragment_parser.set_name('word_or_attention')
|
||||
result = fragment_parser.parse_string(x[0])
|
||||
#result = (pp.OneOrMore(attention | unquoted_fragment) + pp.StringEnd()).parse_string(x[0])
|
||||
print("parsed to", result)
|
||||
return result
|
||||
|
||||
quoted_fragment = pp.QuotedString(quote_char='"', esc_char='\\')
|
||||
quoted_fragment.set_parse_action(parse_fragment_str).set_name('quoted_fragment')
|
||||
quoted_fragment << pp.QuotedString(quote_char='"', esc_char='\\')
|
||||
quoted_fragment.set_parse_action(make_fragment).set_name('quoted_fragment')
|
||||
|
||||
unquoted_fragment = pp.Combine(pp.OneOrMore(
|
||||
unquoted_fragment << pp.Combine(pp.OneOrMore(
|
||||
pp.Literal('\\"').set_debug(False) |
|
||||
pp.Literal('\\').set_debug(False) |
|
||||
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"')
|
||||
))
|
||||
unquoted_fragment.set_parse_action(parse_fragment_str).set_name('unquoted_fragment')
|
||||
unquoted_fragment.set_parse_action(make_fragment).set_name('unquoted_fragment')
|
||||
|
||||
parenthesized_fragment = \
|
||||
(lparen + quoted_fragment.set_debug(True) + rparen).set_name('quoted_paren_internal') | \
|
||||
(lparen + rparen).set_parse_action(lambda x: make_fragment('')) | \
|
||||
parenthesized_fragment << pp.Or([
|
||||
(lparen + quoted_fragment.set_parse_action(parse_fragment_str).set_debug(True) + rparen).set_name('-quoted_paren_internal').set_debug(True),
|
||||
(lparen + rparen).set_parse_action(lambda x: make_fragment('')).set_name('-()').set_debug(True),
|
||||
(lparen + pp.Combine(pp.OneOrMore(
|
||||
pp.Literal('\\)').set_debug(False) |
|
||||
pp.Literal('\\').set_debug(False) |
|
||||
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\)') |
|
||||
pp.Word(string.whitespace)
|
||||
)) + rparen).set_parse_action(parse_fragment_str).set_name('unquoted_paren_internal').set_debug(True)
|
||||
)).set_name('--combined').set_parse_action(parse_fragment_str).set_debug(True) + rparen)]).set_name('-unquoted_paren_internal').set_debug(True)
|
||||
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(True)
|
||||
|
||||
# attention control of the form +(phrase) / -(phrase) / <weight>(phrase)
|
||||
# phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
|
||||
attention = pp.Forward()
|
||||
attention_head = (number | pp.Word('+') | pp.Word('-'))\
|
||||
.set_name("attention_head")\
|
||||
.set_debug(False)
|
||||
@ -352,7 +368,9 @@ class PromptParser():
|
||||
prompt_part.set_debug(False)
|
||||
prompt_part.set_name("prompt_part")
|
||||
|
||||
empty = ((lparen + rparen) | (quotes + quotes)).suppress()
|
||||
empty = (
|
||||
(lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) |
|
||||
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')
|
||||
|
||||
# root prompt definition
|
||||
prompt = (pp.Group(pp.OneOrMore(prompt_part) | empty) + pp.StringEnd()) \
|
||||
|
@ -174,7 +174,7 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
|
||||
parse_prompt('a forest landscape "".swap("in winter")'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment(' ',1)], [Fragment('in winter',1)])])]),
|
||||
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
|
||||
parse_prompt('a forest landscape " ".swap("in winter")'))
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
@ -184,7 +184,7 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
|
||||
parse_prompt('a forest landscape "in winter".swap()'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment(' ',1)])])]),
|
||||
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
|
||||
parse_prompt('a forest landscape "in winter".swap(" ")'))
|
||||
|
||||
def test_cross_attention_control_with_attention(self):
|
||||
@ -201,8 +201,21 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)'))
|
||||
|
||||
def test_single(self):
|
||||
print(parse_prompt('fire (trees and houses).swap("flames")'))
|
||||
|
||||
def make_basic_conjunction(self, strings: list[str]):
|
||||
fragments = [Fragment(x) for x in strings]
|
||||
return Conjunction([FlattenedPrompt(fragments)])
|
||||
|
||||
def make_weighted_conjunction(self, weighted_strings: list[tuple[str,float]]):
|
||||
fragments = [Fragment(x, w) for x,w in weighted_strings]
|
||||
return Conjunction([FlattenedPrompt(fragments)])
|
||||
|
||||
|
||||
def test_escaping(self):
|
||||
self.assertEqual(self.make_basic_conjunction(['mountain \(man\)']),parse_prompt('mountain \(man\)'))
|
||||
self.assertEqual(self.make_basic_conjunction(['mountain (\(man)\)']),parse_prompt('mountain (\(man)\)'))
|
||||
self.assertEqual(self.make_basic_conjunction(['mountain (\(man\))']),parse_prompt('mountain (\(man\))'))
|
||||
#self.assertEqual(self.make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain +(\(man\))'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
Reference in New Issue
Block a user