From da88097abac211e9769ce51c4dd8f79d6ed64e9f Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Thu, 20 Oct 2022 21:41:32 +0200 Subject: [PATCH] fix prompt handling in conditioning.py --- ldm/invoke/conditioning.py | 8 ++++---- ldm/invoke/prompt_parser.py | 39 +++++++++++++++++++++---------------- tests/test_prompt_parser.py | 12 ++++++++++++ 3 files changed, 38 insertions(+), 21 deletions(-) diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index fb6d8d443e..e3685db615 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -66,10 +66,10 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n edited_prompt = FlattenedPrompt() for fragment in flattened_prompt.children: if type(fragment) is CrossAttentionControlSubstitute: - original_prompt.append(fragment.original_fragment) - edited_prompt.append(fragment.edited_fragment) - elif type(fragment) is CrossAttentionControlAppend: - edited_prompt.append(fragment.fragment) + original_prompt.append(fragment.original) + edited_prompt.append(fragment.edited) + #elif type(fragment) is CrossAttentionControlAppend: + # edited_prompt.append(fragment.fragment) else: # regular fragment original_prompt.append(fragment) diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py index d576d069aa..68cc102584 100644 --- a/ldm/invoke/prompt_parser.py +++ b/ldm/invoke/prompt_parser.py @@ -1,4 +1,5 @@ import string +from typing import Union import pyparsing import pyparsing as pp @@ -17,24 +18,31 @@ class Prompt(): def __eq__(self, other): return type(other) is Prompt and other.children == self.children +class BaseFragment: + pass + class FlattenedPrompt(): - def __init__(self, parts: list): + def __init__(self, parts: list=[]): # verify type correctness - parts_converted = [] + self.children = [] for part in parts: - if issubclass(type(part), BaseFragment): - parts_converted.append(part) - elif type(part) is tuple: - # upgrade tuples to Fragments - if type(part[0]) is not str or (type(part[1]) is not float and type(part[1]) is not int): - raise PromptParser.ParsingException( - f"FlattenedPrompt cannot contain {part}, only Fragments or (str, float) tuples are allowed") - parts_converted.append(Fragment(part[0], part[1])) - else: + self.append(part) + + def append(self, fragment: Union[list, BaseFragment, tuple]): + if type(fragment) is list: + for x in fragment: + self.append(x) + elif issubclass(type(fragment), BaseFragment): + self.children.append(fragment) + elif type(fragment) is tuple: + # upgrade tuples to Fragments + if type(fragment[0]) is not str or (type(fragment[1]) is not float and type(fragment[1]) is not int): raise PromptParser.ParsingException( - f"FlattenedPrompt cannot contain {part}, only Fragments or (str, float) tuples are allowed") - # all looks good - self.children = parts_converted + f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed") + self.children.append(Fragment(fragment[0], fragment[1])) + else: + raise PromptParser.ParsingException( + f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed") def __repr__(self): return f"FlattenedPrompt:{self.children}" @@ -42,9 +50,6 @@ class FlattenedPrompt(): return type(other) is FlattenedPrompt and other.children == self.children # abstract base class for Fragments -class BaseFragment: - pass - class Fragment(BaseFragment): def __init__(self, text: str, weight: float=1): assert(type(text) is str) diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py index d053253eb6..2bfae0cb48 100644 --- a/tests/test_prompt_parser.py +++ b/tests/test_prompt_parser.py @@ -156,6 +156,18 @@ class PromptParserTestCase(unittest.TestCase): parse_prompt('("fire ++(flames)", "mountain 2(man)").blend(1,1)')) def test_cross_attention_control(self): + + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]), + Fragment('eating a hotdog', 1)])]), parse_prompt("a \"cat\".swap(dog) eating a hotdog")) + + self.assertEqual(Conjunction([ + FlattenedPrompt([Fragment('a', 1), + CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]), + Fragment('eating a hotdog', 1)])]), parse_prompt("a cat.swap(dog) eating a hotdog")) + + fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \ CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])]) self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))