mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
parsing CrossAttentionControlSubstitute options works
This commit is contained in:
parent
cdb664f6e5
commit
ee7d4d712a
@ -118,16 +118,27 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
|
|||||||
Fragment('sitting on a car')
|
Fragment('sitting on a car')
|
||||||
])
|
])
|
||||||
"""
|
"""
|
||||||
def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list], options: dict=None):
    """Store both sides of the substitution together with its options.

    :param original: the fragment (or list of fragments) to be replaced.
    :param edited: the fragment (or list of fragments) to replace it with.
    :param options: optional overrides for 's_start', 's_end', 't_start'
        and 't_end'. Any key that is not supplied keeps its default
        (starts 0.0, ends 1.0). Other keys are stored as given.
    """
    self.original = original
    self.edited = edited

    # Build a fresh dict on every call so instances never share option state,
    # then layer the caller's overrides on top of the defaults.
    resolved = dict(s_start=0.0, s_end=1.0, t_start=0.0, t_end=1.0)
    if options is not None:
        resolved.update(options)
    self.options = resolved
|
||||||
|
|
||||||
def __repr__(self):
    """Return a debug string showing original -> edited and the options dict."""
    # The outer parentheses wrap the whole payload and the inner pair wraps
    # the options; both pairs are closed so the output is balanced.
    return f"CrossAttentionControlSubstitute:({self.original}->{self.edited} ({self.options}))"
|
||||||
def __eq__(self, other):
    """Equal only to another CrossAttentionControlSubstitute (exact type)
    whose original, edited and options all compare equal to ours."""
    # Exact type match: subclasses and unrelated objects are never equal.
    if type(other) is not CrossAttentionControlSubstitute:
        return False
    return (other.original == self.original
            and other.edited == self.edited
            and other.options == self.options)
|
||||||
|
|
||||||
|
|
||||||
class CrossAttentionControlAppend(CrossAttentionControlledFragment):
|
class CrossAttentionControlAppend(CrossAttentionControlledFragment):
|
||||||
@ -239,7 +250,7 @@ class PromptParser():
|
|||||||
if type(x) is CrossAttentionControlSubstitute:
|
if type(x) is CrossAttentionControlSubstitute:
|
||||||
original_fused = fuse_fragments(x.original)
|
original_fused = fuse_fragments(x.original)
|
||||||
edited_fused = fuse_fragments(x.edited)
|
edited_fused = fuse_fragments(x.edited)
|
||||||
result.append(CrossAttentionControlSubstitute(original_fused, edited_fused))
|
result.append(CrossAttentionControlSubstitute(original_fused, edited_fused, options=x.options))
|
||||||
else:
|
else:
|
||||||
last_weight = result[-1].weight \
|
last_weight = result[-1].weight \
|
||||||
if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \
|
if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \
|
||||||
@ -269,7 +280,7 @@ class PromptParser():
|
|||||||
elif type(node) is CrossAttentionControlSubstitute:
|
elif type(node) is CrossAttentionControlSubstitute:
|
||||||
original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ')
|
original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ')
|
||||||
edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ')
|
edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ')
|
||||||
results += [CrossAttentionControlSubstitute(original, edited)]
|
results += [CrossAttentionControlSubstitute(original, edited, options=node.options)]
|
||||||
elif type(node) is Blend:
|
elif type(node) is Blend:
|
||||||
flattened_subprompts = []
|
flattened_subprompts = []
|
||||||
#print(" flattening blend with prompts", node.prompts, "weights", node.weights)
|
#print(" flattening blend with prompts", node.prompts, "weights", node.weights)
|
||||||
@ -306,6 +317,7 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
lparen = pp.Literal("(").suppress()
|
lparen = pp.Literal("(").suppress()
|
||||||
rparen = pp.Literal(")").suppress()
|
rparen = pp.Literal(")").suppress()
|
||||||
quotes = pp.Literal('"').suppress()
|
quotes = pp.Literal('"').suppress()
|
||||||
|
comma = pp.Literal(",").suppress()
|
||||||
|
|
||||||
# accepts int or float notation, always maps to float
|
# accepts int or float notation, always maps to float
|
||||||
number = pp.pyparsing_common.real | \
|
number = pp.pyparsing_common.real | \
|
||||||
@ -443,7 +455,18 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
parenthesized_fragment.set_debug(debug_cross_attention_control),
|
parenthesized_fragment.set_debug(debug_cross_attention_control),
|
||||||
pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_text_fragment) + pp.FollowedBy(".swap")
|
pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_text_fragment) + pp.FollowedBy(".swap")
|
||||||
])
|
])
|
||||||
edited_fragment = parenthesized_fragment
|
# support keyword=number arguments
|
||||||
|
cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end")])
|
||||||
|
cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number)
|
||||||
|
edited_fragment = pp.MatchFirst([
|
||||||
|
lparen +
|
||||||
|
(quoted_fragment |
|
||||||
|
pp.Group(pp.Word(pp.printables, exclude_chars=string.whitespace + ',').set_parse_action(make_text_fragment))
|
||||||
|
) +
|
||||||
|
pp.Dict(pp.OneOrMore(comma + cross_attention_option)) +
|
||||||
|
rparen,
|
||||||
|
parenthesized_fragment
|
||||||
|
])
|
||||||
cross_attention_substitute << original_fragment + pp.Literal(".swap").suppress() + edited_fragment
|
cross_attention_substitute << original_fragment + pp.Literal(".swap").suppress() + edited_fragment
|
||||||
|
|
||||||
original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
|
original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
|
||||||
@ -451,9 +474,10 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
|
|||||||
cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control)
|
cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control)
|
||||||
|
|
||||||
def make_cross_attention_substitute(x):
    """Parse action: turn matched tokens into a CrossAttentionControlSubstitute.

    x[0] is passed as the original side, x[1] as the edited side, and
    x.as_dict() as the options (presumably the keyword=number pairs gathered
    by the pp.Dict in the edited-fragment grammar -- confirm against pyparsing).
    """
    options = x.as_dict()
    # Trace output is only useful while debugging the grammar; gate it on the
    # same flag that drives set_debug() for the cross-attention elements so
    # normal parsing does not write to stdout.
    if debug_cross_attention_control:
        print("making cacs for", x[0], "->", x[1], "with options", options)
    cacs = CrossAttentionControlSubstitute(x[0], x[1], options=options)
    if debug_cross_attention_control:
        print("made", cacs)
    return cacs
|
||||||
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
|
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
|
||||||
|
|
||||||
|
@ -247,7 +247,22 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)'))
|
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)'))
|
||||||
|
|
||||||
|
def test_cross_attention_control_options(self):
    """keyword=number arguments inside .swap(...) are attached to the parsed
    CrossAttentionControlSubstitute as its options."""
    # (prompt text, options expected on the resulting substitute node)
    cases = [
        ("a \"cat\".swap(dog, s_start=0.1) eating a hotdog",
         {'s_start':0.1}),
        ("a \"cat\".swap(dog, t_start=0.1) eating a hotdog",
         {'t_start':0.1}),
        # Integer notation (20) is expected to come back as a float (20.0).
        ("a \"cat\".swap(dog, t_start=0.1, s_start=20) eating a hotdog",
         {'s_start': 20.0, 't_start':0.1}),
    ]
    for prompt, expected_options in cases:
        expected = Conjunction([
            FlattenedPrompt([Fragment('a', 1),
                             CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)],
                                                             options=expected_options),
                             Fragment('eating a hotdog', 1)])])
        self.assertEqual(expected, parse_prompt(prompt))
|
||||||
|
|
||||||
def test_escaping(self):
|
def test_escaping(self):
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user