mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix attention weight inside .swap()
This commit is contained in:
parent
f73d349dfe
commit
e20108878c
@ -129,7 +129,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
|
|||||||
|
|
||||||
default_options = {
|
default_options = {
|
||||||
's_start': 0.0,
|
's_start': 0.0,
|
||||||
's_end': 0.206, # ~= shape_freedom=0.5
|
's_end': 0.2062994740159002, # ~= shape_freedom=0.5
|
||||||
't_start': 0.0,
|
't_start': 0.0,
|
||||||
't_end': 1.0
|
't_end': 1.0
|
||||||
}
|
}
|
||||||
@ -145,7 +145,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
|
|||||||
# so for shape_freedom = 0.5 we probably want s_end to be 0.2
|
# so for shape_freedom = 0.5 we probably want s_end to be 0.2
|
||||||
# -> cube root and subtract from 1.0
|
# -> cube root and subtract from 1.0
|
||||||
merged_options['s_end'] = 1.0 - shape_freedom ** (1. / 3.)
|
merged_options['s_end'] = 1.0 - shape_freedom ** (1. / 3.)
|
||||||
print('converted shape_freedom argument to', merged_options)
|
#print('converted shape_freedom argument to', merged_options)
|
||||||
merged_options.update(options)
|
merged_options.update(options)
|
||||||
|
|
||||||
self.options = merged_options
|
self.options = merged_options
|
||||||
@ -514,10 +514,11 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
|
|
||||||
# cross attention control
|
# cross attention control
|
||||||
debug_cross_attention_control = False
|
debug_cross_attention_control = False
|
||||||
original_fragment = pp.Or([empty_string.set_debug(debug_cross_attention_control),
|
original_fragment = pp.MatchFirst([
|
||||||
quoted_fragment.set_debug(debug_cross_attention_control),
|
quoted_fragment.set_debug(debug_cross_attention_control),
|
||||||
parenthesized_fragment.set_debug(debug_cross_attention_control),
|
parenthesized_fragment.set_debug(debug_cross_attention_control),
|
||||||
pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_text_fragment) + pp.FollowedBy(".swap")
|
pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_text_fragment) + pp.FollowedBy(".swap"),
|
||||||
|
empty_string.set_debug(debug_cross_attention_control),
|
||||||
])
|
])
|
||||||
# support keyword=number arguments
|
# support keyword=number arguments
|
||||||
cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")])
|
cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")])
|
||||||
@ -525,8 +526,8 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
edited_fragment = pp.MatchFirst([
|
edited_fragment = pp.MatchFirst([
|
||||||
(lparen + rparen).set_parse_action(lambda x: Fragment('')),
|
(lparen + rparen).set_parse_action(lambda x: Fragment('')),
|
||||||
lparen +
|
lparen +
|
||||||
(quoted_fragment |
|
(quoted_fragment | attention |
|
||||||
pp.Group(pp.ZeroOrMore(pp.Word(pp.printables, exclude_chars=string.whitespace + ',').set_parse_action(make_text_fragment)))
|
pp.Group(pp.ZeroOrMore(build_escaped_word_parser_charbychar(',)').set_parse_action(make_text_fragment)))
|
||||||
) +
|
) +
|
||||||
pp.Dict(pp.ZeroOrMore(comma + cross_attention_option)) +
|
pp.Dict(pp.ZeroOrMore(comma + cross_attention_option)) +
|
||||||
rparen,
|
rparen,
|
||||||
|
@ -250,6 +250,33 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0'))
|
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0'))
|
||||||
|
|
||||||
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
|
||||||
|
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
|
||||||
|
Fragment('eating a', 1),
|
||||||
|
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))])
|
||||||
|
])]),
|
||||||
|
parse_prompt("a cat.swap(dog) eating a hotdog.swap(hotdog++++, shape_freedom=0.5)"))
|
||||||
|
|
||||||
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
|
||||||
|
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
|
||||||
|
Fragment('eating a', 1),
|
||||||
|
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))])
|
||||||
|
])]),
|
||||||
|
parse_prompt("a cat.swap(dog) eating a hotdog.swap(\"hotdog++++\", shape_freedom=0.5)"))
|
||||||
|
|
||||||
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
|
||||||
|
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
|
||||||
|
Fragment('eating a', 1),
|
||||||
|
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(1.1,4))])
|
||||||
|
])]),
|
||||||
|
parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog++++, shape_freedom=0.5)"))
|
||||||
|
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
|
||||||
|
CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
|
||||||
|
Fragment('eating a', 1),
|
||||||
|
CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(1.1,4))])
|
||||||
|
])]),
|
||||||
|
parse_prompt("a cat.swap(dog) eating a hotdog.swap(\"h\(o\)tdog++++\", shape_freedom=0.5)"))
|
||||||
|
|
||||||
def test_cross_attention_control_options(self):
|
def test_cross_attention_control_options(self):
|
||||||
self.assertEqual(Conjunction([
|
self.assertEqual(Conjunction([
|
||||||
FlattenedPrompt([Fragment('a', 1),
|
FlattenedPrompt([Fragment('a', 1),
|
||||||
|
Loading…
Reference in New Issue
Block a user