diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index 43eea8736f..3dbcc1bb4b 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -143,7 +143,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment): ]) """ def __init__(self, original: list, edited: list, options: dict=None): - self.original = original + self.original = original if len(original)>0 else [Fragment('')] self.edited = edited if len(edited)>0 else [Fragment('')] default_options = { diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index 6884b62748..0c9bbf91f9 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -272,6 +272,9 @@ class PromptParserTestCase(unittest.TestCase): self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]), parse_prompt('a forest landscape "".swap("in winter")')) + self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), + CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]), + parse_prompt('a forest landscape ().swap(in winter)')) self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]), parse_prompt('a forest landscape " ".swap("in winter")'))