mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix issue with hot-dog, improve () suppression
This commit is contained in:
parent
582e19056a
commit
135c62f1a4
@ -358,10 +358,11 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
quoted_fragment = pp.Forward()
|
quoted_fragment = pp.Forward()
|
||||||
parenthesized_fragment = pp.Forward()
|
parenthesized_fragment = pp.Forward()
|
||||||
cross_attention_substitute = pp.Forward()
|
cross_attention_substitute = pp.Forward()
|
||||||
prompt_part = pp.Forward()
|
|
||||||
|
|
||||||
def make_text_fragment(x):
|
def make_text_fragment(x):
|
||||||
#print("### making fragment for", x)
|
#print("### making fragment for", x)
|
||||||
|
if type(x[0]) is Fragment:
|
||||||
|
assert(False)
|
||||||
if type(x) is str:
|
if type(x) is str:
|
||||||
return Fragment(x)
|
return Fragment(x)
|
||||||
elif type(x) is pp.ParseResults or type(x) is list:
|
elif type(x) is pp.ParseResults or type(x) is list:
|
||||||
@ -396,8 +397,10 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
|
|
||||||
|
|
||||||
def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
|
def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
|
||||||
#print(f"parsing fragment string \"{x}\"")
|
#print(f"parsing fragment string for {x}")
|
||||||
fragment_string = x[0]
|
fragment_string = x[0]
|
||||||
|
#print(f"ppparsing fragment string \"{fragment_string}\"")
|
||||||
|
|
||||||
if len(fragment_string.strip()) == 0:
|
if len(fragment_string.strip()) == 0:
|
||||||
return Fragment('')
|
return Fragment('')
|
||||||
|
|
||||||
@ -406,13 +409,16 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
fragment_string = fragment_string.replace('"', '\\"')
|
fragment_string = fragment_string.replace('"', '\\"')
|
||||||
|
|
||||||
#fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment))))
|
#fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment))))
|
||||||
result = pp.Group(pp.MatchFirst([
|
try:
|
||||||
pp.OneOrMore(prompt_part | quoted_fragment),
|
result = pp.Group(pp.MatchFirst([
|
||||||
pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd()
|
pp.OneOrMore(quoted_fragment | attention | unquoted_word).set_name('pf_str_qfuq'),
|
||||||
])).set_name('rr').set_debug(False).parse_string(fragment_string)
|
pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd()
|
||||||
#result = (pp.OneOrMore(attention | unquoted_word) + pp.StringEnd()).parse_string(x[0])
|
])).set_name('blend-result').set_debug(False).parse_string(fragment_string)
|
||||||
#print("parsed to", result)
|
#print("parsed to", result)
|
||||||
return result
|
return result
|
||||||
|
except pp.ParseException as e:
|
||||||
|
print("parse_fragment_str couldn't parse prompt string:", e)
|
||||||
|
raise
|
||||||
|
|
||||||
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
|
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
|
||||||
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment')
|
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment')
|
||||||
@ -422,14 +428,21 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')')
|
escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')')
|
||||||
escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"')
|
escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"')
|
||||||
|
|
||||||
|
empty = (
|
||||||
|
(lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) |
|
||||||
|
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')
|
||||||
|
|
||||||
|
|
||||||
def not_ends_with_swap(x):
|
def not_ends_with_swap(x):
|
||||||
#print("trying to match:", x)
|
#print("trying to match:", x)
|
||||||
return not x[0].endswith('.swap')
|
return not x[0].endswith('.swap')
|
||||||
|
|
||||||
unquoted_fragment = pp.Combine(pp.OneOrMore(
|
unquoted_word = pp.Combine(pp.OneOrMore(
|
||||||
escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash |
|
escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash |
|
||||||
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()')))
|
(pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') + (pp.NotAny(pp.Word('+') | pp.Word('-'))))
|
||||||
unquoted_fragment.set_parse_action(make_text_fragment).set_name('unquoted_fragment').set_debug(False)
|
))
|
||||||
|
|
||||||
|
unquoted_word.set_parse_action(make_text_fragment).set_name('unquoted_word').set_debug(False)
|
||||||
#print(unquoted_fragment.parse_string("cat.swap(dog)"))
|
#print(unquoted_fragment.parse_string("cat.swap(dog)"))
|
||||||
|
|
||||||
parenthesized_fragment << pp.Or([
|
parenthesized_fragment << pp.Or([
|
||||||
@ -510,15 +523,16 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")])
|
cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")])
|
||||||
cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number)
|
cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number)
|
||||||
edited_fragment = pp.MatchFirst([
|
edited_fragment = pp.MatchFirst([
|
||||||
|
(lparen + rparen).set_parse_action(lambda x: Fragment('')),
|
||||||
lparen +
|
lparen +
|
||||||
(quoted_fragment |
|
(quoted_fragment |
|
||||||
pp.Group(pp.OneOrMore(pp.Word(pp.printables, exclude_chars=string.whitespace + ',').set_parse_action(make_text_fragment)))
|
pp.Group(pp.ZeroOrMore(pp.Word(pp.printables, exclude_chars=string.whitespace + ',').set_parse_action(make_text_fragment)))
|
||||||
) +
|
) +
|
||||||
pp.Dict(pp.OneOrMore(comma + cross_attention_option)) +
|
pp.Dict(pp.ZeroOrMore(comma + cross_attention_option)) +
|
||||||
rparen,
|
rparen,
|
||||||
parenthesized_fragment
|
parenthesized_fragment
|
||||||
])
|
])
|
||||||
cross_attention_substitute << original_fragment + pp.Literal(".swap").suppress() + edited_fragment
|
cross_attention_substitute << original_fragment + pp.Literal(".swap").set_debug(False).suppress() + edited_fragment
|
||||||
|
|
||||||
original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
|
original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
|
||||||
edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control)
|
edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control)
|
||||||
@ -533,24 +547,18 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
|
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
|
||||||
|
|
||||||
|
|
||||||
# simple fragments of text
|
|
||||||
# use Or to match the longest
|
|
||||||
prompt_part << pp.MatchFirst([
|
|
||||||
cross_attention_substitute,
|
|
||||||
attention,
|
|
||||||
unquoted_fragment,
|
|
||||||
lparen + unquoted_fragment + rparen # matches case where user has +(term) and just deletes the +
|
|
||||||
])
|
|
||||||
prompt_part.set_debug(False)
|
|
||||||
prompt_part.set_name("prompt_part")
|
|
||||||
|
|
||||||
empty = (
|
|
||||||
(lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) |
|
|
||||||
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')
|
|
||||||
|
|
||||||
# root prompt definition
|
# root prompt definition
|
||||||
prompt = (pp.OneOrMore(pp.Or([prompt_part, quoted_fragment, empty])) + pp.StringEnd()) \
|
debug_root_prompt = False
|
||||||
.set_parse_action(lambda x: Prompt(x))
|
prompt = (pp.OneOrMore(pp.Or([cross_attention_substitute.set_debug(debug_root_prompt),
|
||||||
|
attention.set_debug(debug_root_prompt),
|
||||||
|
quoted_fragment.set_debug(debug_root_prompt),
|
||||||
|
(lparen + (pp.ZeroOrMore(unquoted_word | pp.White().suppress()).leave_whitespace()) + rparen).set_name('parenthesized-uqw').set_debug(debug_root_prompt),
|
||||||
|
unquoted_word.set_debug(debug_root_prompt),
|
||||||
|
empty.set_parse_action(make_text_fragment).set_debug(debug_root_prompt)])
|
||||||
|
) + pp.StringEnd()) \
|
||||||
|
.set_name('prompt') \
|
||||||
|
.set_parse_action(lambda x: Prompt(x)) \
|
||||||
|
.set_debug(debug_root_prompt)
|
||||||
|
|
||||||
#print("parsing test:", prompt.parse_string("spaced eyes--"))
|
#print("parsing test:", prompt.parse_string("spaced eyes--"))
|
||||||
#print("parsing test:", prompt.parse_string("eyes--"))
|
#print("parsing test:", prompt.parse_string("eyes--"))
|
||||||
@ -567,7 +575,7 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
if len(x_unquoted.strip()) == 0:
|
if len(x_unquoted.strip()) == 0:
|
||||||
# print(' b : just an empty string')
|
# print(' b : just an empty string')
|
||||||
return Prompt([Fragment('')])
|
return Prompt([Fragment('')])
|
||||||
# print(' b parsing ', c_unquoted)
|
#print(f' b parsing \'{x_unquoted}\'')
|
||||||
x_parsed = prompt.parse_string(x_unquoted)
|
x_parsed = prompt.parse_string(x_unquoted)
|
||||||
#print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed)
|
#print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed)
|
||||||
return x_parsed[0]
|
return x_parsed[0]
|
||||||
|
@ -32,6 +32,7 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames"))
|
self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames"))
|
||||||
self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames"))
|
self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames"))
|
||||||
self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire"))
|
self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire"))
|
||||||
|
self.assertEqual(make_weighted_conjunction([("cat hot-dog eating", 1)]), parse_prompt("cat hot-dog eating"))
|
||||||
|
|
||||||
def test_attention(self):
|
def test_attention(self):
|
||||||
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5"))
|
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5"))
|
||||||
@ -106,10 +107,7 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
#with self.assertRaises(pyparsing.ParseException):
|
#with self.assertRaises(pyparsing.ParseException):
|
||||||
with self.assertRaises(pyparsing.ParseException):
|
with self.assertRaises(pyparsing.ParseException):
|
||||||
parse_prompt('a badly (formed +test prompt')
|
parse_prompt('a badly (formed +test prompt')
|
||||||
with self.assertRaises(pyparsing.ParseException):
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a badly formed +test prompt',1)])]) , parse_prompt('a badly (formed +test )prompt'))
|
||||||
parse_prompt('a badly (formed +test )prompt')
|
|
||||||
with self.assertRaises(pyparsing.ParseException):
|
|
||||||
parse_prompt('a badly (formed +test )prompt')
|
|
||||||
with self.assertRaises(pyparsing.ParseException):
|
with self.assertRaises(pyparsing.ParseException):
|
||||||
parse_prompt('(((a badly (formed +test )prompt')
|
parse_prompt('(((a badly (formed +test )prompt')
|
||||||
with self.assertRaises(pyparsing.ParseException):
|
with self.assertRaises(pyparsing.ParseException):
|
||||||
@ -394,6 +392,9 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
# todo handle this
|
# todo handle this
|
||||||
#self.assertEqual(make_basic_conjunction(['a badly formed +test prompt']),
|
#self.assertEqual(make_basic_conjunction(['a badly formed +test prompt']),
|
||||||
# parse_prompt('a badly formed +test prompt'))
|
# parse_prompt('a badly formed +test prompt'))
|
||||||
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||||
|
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
|
||||||
|
parse_prompt('a forest landscape "in winter".swap()'))
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user