mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix prompt handling in conditioning.py
This commit is contained in:
parent
3f13dd3ae8
commit
da88097aba
@ -66,10 +66,10 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
||||
edited_prompt = FlattenedPrompt()
|
||||
for fragment in flattened_prompt.children:
|
||||
if type(fragment) is CrossAttentionControlSubstitute:
|
||||
original_prompt.append(fragment.original_fragment)
|
||||
edited_prompt.append(fragment.edited_fragment)
|
||||
elif type(fragment) is CrossAttentionControlAppend:
|
||||
edited_prompt.append(fragment.fragment)
|
||||
original_prompt.append(fragment.original)
|
||||
edited_prompt.append(fragment.edited)
|
||||
#elif type(fragment) is CrossAttentionControlAppend:
|
||||
# edited_prompt.append(fragment.fragment)
|
||||
else:
|
||||
# regular fragment
|
||||
original_prompt.append(fragment)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import string
|
||||
from typing import Union
|
||||
|
||||
import pyparsing
|
||||
import pyparsing as pp
|
||||
@ -17,24 +18,31 @@ class Prompt():
|
||||
def __eq__(self, other):
|
||||
return type(other) is Prompt and other.children == self.children
|
||||
|
||||
class BaseFragment:
|
||||
pass
|
||||
|
||||
class FlattenedPrompt():
|
||||
def __init__(self, parts: list):
|
||||
def __init__(self, parts: list=[]):
|
||||
# verify type correctness
|
||||
parts_converted = []
|
||||
self.children = []
|
||||
for part in parts:
|
||||
if issubclass(type(part), BaseFragment):
|
||||
parts_converted.append(part)
|
||||
elif type(part) is tuple:
|
||||
self.append(part)
|
||||
|
||||
def append(self, fragment: Union[list, BaseFragment, tuple]):
|
||||
if type(fragment) is list:
|
||||
for x in fragment:
|
||||
self.append(x)
|
||||
elif issubclass(type(fragment), BaseFragment):
|
||||
self.children.append(fragment)
|
||||
elif type(fragment) is tuple:
|
||||
# upgrade tuples to Fragments
|
||||
if type(part[0]) is not str or (type(part[1]) is not float and type(part[1]) is not int):
|
||||
if type(fragment[0]) is not str or (type(fragment[1]) is not float and type(fragment[1]) is not int):
|
||||
raise PromptParser.ParsingException(
|
||||
f"FlattenedPrompt cannot contain {part}, only Fragments or (str, float) tuples are allowed")
|
||||
parts_converted.append(Fragment(part[0], part[1]))
|
||||
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
|
||||
self.children.append(Fragment(fragment[0], fragment[1]))
|
||||
else:
|
||||
raise PromptParser.ParsingException(
|
||||
f"FlattenedPrompt cannot contain {part}, only Fragments or (str, float) tuples are allowed")
|
||||
# all looks good
|
||||
self.children = parts_converted
|
||||
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
|
||||
|
||||
def __repr__(self):
|
||||
return f"FlattenedPrompt:{self.children}"
|
||||
@ -42,9 +50,6 @@ class FlattenedPrompt():
|
||||
return type(other) is FlattenedPrompt and other.children == self.children
|
||||
|
||||
# abstract base class for Fragments
|
||||
class BaseFragment:
|
||||
pass
|
||||
|
||||
class Fragment(BaseFragment):
|
||||
def __init__(self, text: str, weight: float=1):
|
||||
assert(type(text) is str)
|
||||
|
@ -156,6 +156,18 @@ class PromptParserTestCase(unittest.TestCase):
|
||||
parse_prompt('("fire ++(flames)", "mountain 2(man)").blend(1,1)'))
|
||||
|
||||
def test_cross_attention_control(self):
|
||||
|
||||
self.assertEqual(Conjunction([
|
||||
FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
|
||||
Fragment('eating a hotdog', 1)])]), parse_prompt("a \"cat\".swap(dog) eating a hotdog"))
|
||||
|
||||
self.assertEqual(Conjunction([
|
||||
FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
|
||||
Fragment('eating a hotdog', 1)])]), parse_prompt("a cat.swap(dog) eating a hotdog"))
|
||||
|
||||
|
||||
fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \
|
||||
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])])
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))
|
||||
|
Loading…
Reference in New Issue
Block a user