diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 39138e5364..f5b369bc48 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -118,16 +118,27 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment): Fragment('sitting on a car') ]) """ - def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list]): + def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list], options: dict=None): self.original = original self.edited = edited + default_options = { + 's_start': 0.0, + 's_end': 1.0, + 't_start': 0.0, + 't_end': 1.0 + } + merged_options = default_options + if options is not None: + merged_options.update(options) + self.options = merged_options def __repr__(self): - return f"CrossAttentionControlSubstitute:({self.original}->{self.edited})" + return f"CrossAttentionControlSubstitute:({self.original}->{self.edited} ({self.options})" def __eq__(self, other): return type(other) is CrossAttentionControlSubstitute \ and other.original == self.original \ - and other.edited == self.edited + and other.edited == self.edited \ + and other.options == self.options class CrossAttentionControlAppend(CrossAttentionControlledFragment): @@ -239,7 +250,7 @@ class PromptParser(): if type(x) is CrossAttentionControlSubstitute: original_fused = fuse_fragments(x.original) edited_fused = fuse_fragments(x.edited) - result.append(CrossAttentionControlSubstitute(original_fused, edited_fused)) + result.append(CrossAttentionControlSubstitute(original_fused, edited_fused, options=x.options)) else: last_weight = result[-1].weight \ if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \ @@ -269,7 +280,7 @@ class PromptParser(): elif type(node) is CrossAttentionControlSubstitute: original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ') edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ') - results += [CrossAttentionControlSubstitute(original, edited)] + results += [CrossAttentionControlSubstitute(original, edited, options=node.options)] elif type(node) is Blend: flattened_subprompts = [] #print(" flattening blend with prompts", node.prompts, "weights", node.weights) @@ -306,6 +317,7 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) lparen = pp.Literal("(").suppress() rparen = pp.Literal(")").suppress() quotes = pp.Literal('"').suppress() + comma = pp.Literal(",").suppress() # accepts int or float notation, always maps to float number = pp.pyparsing_common.real | \ @@ -443,7 +455,18 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) parenthesized_fragment.set_debug(debug_cross_attention_control), pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_text_fragment) + pp.FollowedBy(".swap") ]) - edited_fragment = parenthesized_fragment + # support keyword=number arguments + cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end")]) + cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number) + edited_fragment = pp.MatchFirst([ + lparen + + (quoted_fragment | + pp.Group(pp.Word(pp.printables, exclude_chars=string.whitespace + ',').set_parse_action(make_text_fragment)) + ) + + pp.Dict(pp.OneOrMore(comma + cross_attention_option)) + + rparen, + parenthesized_fragment + ]) cross_attention_substitute << original_fragment + pp.Literal(".swap").suppress() + edited_fragment original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control) @@ -451,9 +474,10 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float) cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control) def make_cross_attention_substitute(x): - #print("making cacs for", x) - cacs = CrossAttentionControlSubstitute(x[0], x[1]) - #print("made", cacs) + print("making cacs for", x[0], "->", x[1], "with options", x.as_dict()) + #if len(x>2): + cacs = CrossAttentionControlSubstitute(x[0], x[1], options=x.as_dict()) + print("made", cacs) return cacs cross_attention_substitute.set_parse_action(make_cross_attention_substitute) diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 84971fcc52..02644012d8 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -247,7 +247,22 @@ class PromptParserTestCase(unittest.TestCase): Fragment(',', 1), Fragment('fire', 2.0)])]) self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)')) - + def test_cross_attention_control_options(self): + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start':0.1}), + Fragment('eating a hotdog', 1)])]), + parse_prompt("a \"cat\".swap(dog, s_start=0.1) eating a hotdog")) + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'t_start':0.1}), + Fragment('eating a hotdog', 1)])]), + parse_prompt("a \"cat\".swap(dog, t_start=0.1) eating a hotdog")) + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start': 20.0, 't_start':0.1}), + Fragment('eating a hotdog', 1)])]), + parse_prompt("a \"cat\".swap(dog, t_start=0.1, s_start=20) eating a hotdog")) def test_escaping(self):