fix prompt handling in conditioning.py

This commit is contained in:
Damian at mba 2022-10-20 21:41:32 +02:00
parent 3f13dd3ae8
commit da88097aba
3 changed files with 38 additions and 21 deletions

View File

@ -66,10 +66,10 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
edited_prompt = FlattenedPrompt() edited_prompt = FlattenedPrompt()
for fragment in flattened_prompt.children: for fragment in flattened_prompt.children:
if type(fragment) is CrossAttentionControlSubstitute: if type(fragment) is CrossAttentionControlSubstitute:
original_prompt.append(fragment.original_fragment) original_prompt.append(fragment.original)
edited_prompt.append(fragment.edited_fragment) edited_prompt.append(fragment.edited)
elif type(fragment) is CrossAttentionControlAppend: #elif type(fragment) is CrossAttentionControlAppend:
edited_prompt.append(fragment.fragment) # edited_prompt.append(fragment.fragment)
else: else:
# regular fragment # regular fragment
original_prompt.append(fragment) original_prompt.append(fragment)

View File

@ -1,4 +1,5 @@
import string import string
from typing import Union
import pyparsing import pyparsing
import pyparsing as pp import pyparsing as pp
@ -17,24 +18,31 @@ class Prompt():
def __eq__(self, other): def __eq__(self, other):
return type(other) is Prompt and other.children == self.children return type(other) is Prompt and other.children == self.children
class BaseFragment:
pass
class FlattenedPrompt(): class FlattenedPrompt():
def __init__(self, parts: list): def __init__(self, parts: list=[]):
# verify type correctness # verify type correctness
parts_converted = [] self.children = []
for part in parts: for part in parts:
if issubclass(type(part), BaseFragment): self.append(part)
parts_converted.append(part)
elif type(part) is tuple: def append(self, fragment: Union[list, BaseFragment, tuple]):
# upgrade tuples to Fragments if type(fragment) is list:
if type(part[0]) is not str or (type(part[1]) is not float and type(part[1]) is not int): for x in fragment:
raise PromptParser.ParsingException( self.append(x)
f"FlattenedPrompt cannot contain {part}, only Fragments or (str, float) tuples are allowed") elif issubclass(type(fragment), BaseFragment):
parts_converted.append(Fragment(part[0], part[1])) self.children.append(fragment)
else: elif type(fragment) is tuple:
# upgrade tuples to Fragments
if type(fragment[0]) is not str or (type(fragment[1]) is not float and type(fragment[1]) is not int):
raise PromptParser.ParsingException( raise PromptParser.ParsingException(
f"FlattenedPrompt cannot contain {part}, only Fragments or (str, float) tuples are allowed") f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
# all looks good self.children.append(Fragment(fragment[0], fragment[1]))
self.children = parts_converted else:
raise PromptParser.ParsingException(
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
def __repr__(self): def __repr__(self):
return f"FlattenedPrompt:{self.children}" return f"FlattenedPrompt:{self.children}"
@ -42,9 +50,6 @@ class FlattenedPrompt():
return type(other) is FlattenedPrompt and other.children == self.children return type(other) is FlattenedPrompt and other.children == self.children
# abstract base class for Fragments # abstract base class for Fragments
class BaseFragment:
pass
class Fragment(BaseFragment): class Fragment(BaseFragment):
def __init__(self, text: str, weight: float=1): def __init__(self, text: str, weight: float=1):
assert(type(text) is str) assert(type(text) is str)

View File

@ -156,6 +156,18 @@ class PromptParserTestCase(unittest.TestCase):
parse_prompt('("fire ++(flames)", "mountain 2(man)").blend(1,1)')) parse_prompt('("fire ++(flames)", "mountain 2(man)").blend(1,1)'))
def test_cross_attention_control(self): def test_cross_attention_control(self):
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
Fragment('eating a hotdog', 1)])]), parse_prompt("a \"cat\".swap(dog) eating a hotdog"))
self.assertEqual(Conjunction([
FlattenedPrompt([Fragment('a', 1),
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
Fragment('eating a hotdog', 1)])]), parse_prompt("a cat.swap(dog) eating a hotdog"))
fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \ fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])]) CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])])
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)')) self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))