mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
bring in prompt parser from fix-prompts branch
attention is parsed but ignored, blends old syntax doesn't work, conjunctions are parsed but ignored, the only part that's used here is the new .blend() syntax and cross-attention control using .swap()
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
import unittest
|
||||
|
||||
from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute
|
||||
from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute, \
|
||||
Fragment
|
||||
|
||||
|
||||
def parse_prompt(prompt_string):
|
||||
@ -135,7 +136,7 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
|
||||
def test_cross_attention_control(self):
|
||||
fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \
|
||||
CrossAttentionControlSubstitute('flames', 'trees')])])
|
||||
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])])
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)'))
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)'))
|
||||
@ -144,13 +145,13 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")'))
|
||||
|
||||
fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \
|
||||
CrossAttentionControlSubstitute('flames', 'trees and houses')])])
|
||||
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees and houses', 1)])])])
|
||||
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")'))
|
||||
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")'))
|
||||
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")'))
|
||||
|
||||
trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \
|
||||
CrossAttentionControlSubstitute('trees and houses', 'flames')])])
|
||||
CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])])
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")'))
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")'))
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")'))
|
||||
@ -159,14 +160,46 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)'))
|
||||
|
||||
flames_to_trees_fire = Conjunction([FlattenedPrompt([
|
||||
CrossAttentionControlSubstitute('flames', 'trees'),
|
||||
CrossAttentionControlSubstitute([Fragment('flames',1)], [Fragment('trees',1)]),
|
||||
(', fire', 1.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire'))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire'))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire'))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))
|
||||
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
|
||||
parse_prompt('a forest landscape "".swap("in winter")'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment(' ',1)], [Fragment('in winter',1)])])]),
|
||||
parse_prompt('a forest landscape " ".swap("in winter")'))
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
|
||||
parse_prompt('a forest landscape "in winter".swap("")'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
|
||||
parse_prompt('a forest landscape "in winter".swap()'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment(' ',1)])])]),
|
||||
parse_prompt('a forest landscape "in winter".swap(" ")'))
|
||||
|
||||
def test_cross_attention_control_with_attention(self):
|
||||
flames_to_trees_fire = Conjunction([FlattenedPrompt([
|
||||
CrossAttentionControlSubstitute([Fragment('flames',0.5)], [Fragment('trees',0.7)]),
|
||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(flames)".swap("0.7(trees)"), 2.0(fire)'))
|
||||
flames_to_trees_fire = Conjunction([FlattenedPrompt([
|
||||
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7)]),
|
||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees)"), 2.0(fire)'))
|
||||
flames_to_trees_fire = Conjunction([FlattenedPrompt([
|
||||
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7), Fragment('houses', 1)]),
|
||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Reference in New Issue
Block a user