fix prompt handling in conditioning.py

This commit is contained in:
Damian at mba
2022-10-20 21:41:32 +02:00
parent 3f13dd3ae8
commit da88097aba
3 changed files with 38 additions and 21 deletions

View File

@ -156,6 +156,18 @@ class PromptParserTestCase(unittest.TestCase):
parse_prompt('("fire ++(flames)", "mountain 2(man)").blend(1,1)'))
def test_cross_attention_control(self):
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
Fragment('eating a hotdog', 1)])]), parse_prompt("a \"cat\".swap(dog) eating a hotdog"))
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
Fragment('eating a hotdog', 1)])]), parse_prompt("a cat.swap(dog) eating a hotdog"))
fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])])
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))