mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
be more forgiving about prompts with ((words))
This commit is contained in:
parent
30745f163d
commit
245cf606a3
@ -417,7 +417,7 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
#print("parsed to", result)
|
#print("parsed to", result)
|
||||||
return result
|
return result
|
||||||
except pp.ParseException as e:
|
except pp.ParseException as e:
|
||||||
print("parse_fragment_str couldn't parse prompt string:", e)
|
#print("parse_fragment_str couldn't parse prompt string:", e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
|
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
|
||||||
@ -445,14 +445,17 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
unquoted_word.set_parse_action(make_text_fragment).set_name('unquoted_word').set_debug(False)
|
unquoted_word.set_parse_action(make_text_fragment).set_name('unquoted_word').set_debug(False)
|
||||||
#print(unquoted_fragment.parse_string("cat.swap(dog)"))
|
#print(unquoted_fragment.parse_string("cat.swap(dog)"))
|
||||||
|
|
||||||
parenthesized_fragment << pp.Or([
|
parenthesized_fragment << (lparen +
|
||||||
(lparen + quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False) + rparen).set_name('-quoted_paren_internal').set_debug(False),
|
pp.Or([
|
||||||
(lparen + rparen).set_parse_action(lambda x: make_text_fragment('')).set_name('-()').set_debug(False),
|
(parenthesized_fragment),
|
||||||
(lparen + pp.Combine(pp.OneOrMore(
|
(quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False)).set_name('-quoted_paren_internal').set_debug(False),
|
||||||
|
(pp.Combine(pp.OneOrMore(
|
||||||
escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash |
|
escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash |
|
||||||
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') |
|
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') |
|
||||||
pp.Word(string.whitespace)
|
pp.Word(string.whitespace)
|
||||||
)).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False) + rparen)]).set_name('-unquoted_paren_internal').set_debug(False)
|
)).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False)),
|
||||||
|
pp.Empty()
|
||||||
|
]) + rparen)
|
||||||
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False)
|
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False)
|
||||||
|
|
||||||
debug_attention = False
|
debug_attention = False
|
||||||
@ -472,11 +475,12 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
+ rparen + attention_with_parens_foot)
|
+ rparen + attention_with_parens_foot)
|
||||||
attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention)
|
attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention)
|
||||||
|
|
||||||
attention_without_parens_foot = pp.Or(pp.Word('+') | pp.Word('-')).set_name('attention_without_parens_foots')
|
attention_without_parens_foot = pp.NotAny(pp.White()) + pp.Or(pp.Word('+') | pp.Word('-')).set_name('attention_without_parens_foots')
|
||||||
attention_without_parens <<= pp.Group(
|
attention_without_parens <<= pp.Group(pp.MatchFirst([
|
||||||
(quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot) |
|
quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot,
|
||||||
pp.Combine(build_escaped_word_parser_charbychar('()+-')).set_name('attention_word_without_parens').set_debug(debug_attention)#.set_parse_action(lambda x: print('escapéd', x))
|
pp.Combine(build_escaped_word_parser_charbychar('()+-')).set_name('attention_word_without_parens').set_debug(debug_attention)#.set_parse_action(lambda x: print('escapéd', x))
|
||||||
+ attention_without_parens_foot)#.leave_whitespace()
|
+ attention_without_parens_foot#.leave_whitespace()
|
||||||
|
]))
|
||||||
attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention)
|
attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention)
|
||||||
|
|
||||||
|
|
||||||
@ -553,7 +557,7 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
prompt = (pp.OneOrMore(pp.Or([cross_attention_substitute.set_debug(debug_root_prompt),
|
prompt = (pp.OneOrMore(pp.Or([cross_attention_substitute.set_debug(debug_root_prompt),
|
||||||
attention.set_debug(debug_root_prompt),
|
attention.set_debug(debug_root_prompt),
|
||||||
quoted_fragment.set_debug(debug_root_prompt),
|
quoted_fragment.set_debug(debug_root_prompt),
|
||||||
(lparen + (pp.ZeroOrMore(unquoted_word | pp.White().suppress()).leave_whitespace()) + rparen).set_name('parenthesized-uqw').set_debug(debug_root_prompt),
|
parenthesized_fragment.set_debug(debug_root_prompt),
|
||||||
unquoted_word.set_debug(debug_root_prompt),
|
unquoted_word.set_debug(debug_root_prompt),
|
||||||
empty.set_parse_action(make_text_fragment).set_debug(debug_root_prompt)])
|
empty.set_parse_action(make_text_fragment).set_debug(debug_root_prompt)])
|
||||||
) + pp.StringEnd()) \
|
) + pp.StringEnd()) \
|
||||||
|
@ -117,6 +117,15 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
with self.assertRaises(pyparsing.ParseException):
|
with self.assertRaises(pyparsing.ParseException):
|
||||||
parse_prompt('("((a badly (formed +test ").blend(1.0)')
|
parse_prompt('("((a badly (formed +test ").blend(1.0)')
|
||||||
|
|
||||||
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
|
||||||
|
parse_prompt("hamburger ((bun))"))
|
||||||
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
|
||||||
|
parse_prompt("hamburger (bun)"))
|
||||||
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger kaiser roll', 1)])]),
|
||||||
|
parse_prompt("hamburger (kaiser roll)"))
|
||||||
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger kaiser roll', 1)])]),
|
||||||
|
parse_prompt("hamburger ((kaiser roll))"))
|
||||||
|
|
||||||
|
|
||||||
def test_blend(self):
|
def test_blend(self):
|
||||||
self.assertEqual(Conjunction(
|
self.assertEqual(Conjunction(
|
||||||
@ -284,6 +293,7 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
])]),
|
])]),
|
||||||
parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog-, shape_freedom=0.5)"))
|
parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog-, shape_freedom=0.5)"))
|
||||||
|
|
||||||
|
|
||||||
def test_cross_attention_control_options(self):
|
def test_cross_attention_control_options(self):
|
||||||
self.assertEqual(Conjunction([
|
self.assertEqual(Conjunction([
|
||||||
FlattenedPrompt([Fragment('a', 1),
|
FlattenedPrompt([Fragment('a', 1),
|
||||||
@ -426,9 +436,9 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
# todo handle this
|
# todo handle this
|
||||||
#self.assertEqual(make_basic_conjunction(['a badly formed +test prompt']),
|
#self.assertEqual(make_basic_conjunction(['a badly formed +test prompt']),
|
||||||
# parse_prompt('a badly formed +test prompt'))
|
# parse_prompt('a badly formed +test prompt'))
|
||||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \
|
||||||
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
|
CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])])
|
||||||
parse_prompt('a forest landscape "in winter".swap()'))
|
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")'))
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user