fix parsing error doing eg forest ().swap(in winter)

This commit is contained in:
damian0815 2022-11-03 12:39:45 +01:00 committed by Lincoln Stein
parent 1f0c5b4cf1
commit b70420951d
2 changed files with 4 additions and 1 deletions

View File

@ -143,7 +143,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
]) ])
""" """
def __init__(self, original: list, edited: list, options: dict=None): def __init__(self, original: list, edited: list, options: dict=None):
self.original = original self.original = original if len(original)>0 else [Fragment('')]
self.edited = edited if len(edited)>0 else [Fragment('')] self.edited = edited if len(edited)>0 else [Fragment('')]
default_options = { default_options = {

View File

@ -272,6 +272,9 @@ class PromptParserTestCase(unittest.TestCase):
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]), CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
parse_prompt('a forest landscape "".swap("in winter")')) parse_prompt('a forest landscape "".swap("in winter")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
parse_prompt('a forest landscape ().swap(in winter)'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]), CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
parse_prompt('a forest landscape " ".swap("in winter")')) parse_prompt('a forest landscape " ".swap("in winter")'))