mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix prompt handling in conditioning.py
This commit is contained in:
parent
3f13dd3ae8
commit
da88097aba
@ -66,10 +66,10 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
|
|||||||
edited_prompt = FlattenedPrompt()
|
edited_prompt = FlattenedPrompt()
|
||||||
for fragment in flattened_prompt.children:
|
for fragment in flattened_prompt.children:
|
||||||
if type(fragment) is CrossAttentionControlSubstitute:
|
if type(fragment) is CrossAttentionControlSubstitute:
|
||||||
original_prompt.append(fragment.original_fragment)
|
original_prompt.append(fragment.original)
|
||||||
edited_prompt.append(fragment.edited_fragment)
|
edited_prompt.append(fragment.edited)
|
||||||
elif type(fragment) is CrossAttentionControlAppend:
|
#elif type(fragment) is CrossAttentionControlAppend:
|
||||||
edited_prompt.append(fragment.fragment)
|
# edited_prompt.append(fragment.fragment)
|
||||||
else:
|
else:
|
||||||
# regular fragment
|
# regular fragment
|
||||||
original_prompt.append(fragment)
|
original_prompt.append(fragment)
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import string
|
import string
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import pyparsing
|
import pyparsing
|
||||||
import pyparsing as pp
|
import pyparsing as pp
|
||||||
@ -17,24 +18,31 @@ class Prompt():
|
|||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
return type(other) is Prompt and other.children == self.children
|
return type(other) is Prompt and other.children == self.children
|
||||||
|
|
||||||
|
class BaseFragment:
|
||||||
|
pass
|
||||||
|
|
||||||
class FlattenedPrompt():
|
class FlattenedPrompt():
|
||||||
def __init__(self, parts: list):
|
def __init__(self, parts: list=[]):
|
||||||
# verify type correctness
|
# verify type correctness
|
||||||
parts_converted = []
|
self.children = []
|
||||||
for part in parts:
|
for part in parts:
|
||||||
if issubclass(type(part), BaseFragment):
|
self.append(part)
|
||||||
parts_converted.append(part)
|
|
||||||
elif type(part) is tuple:
|
def append(self, fragment: Union[list, BaseFragment, tuple]):
|
||||||
# upgrade tuples to Fragments
|
if type(fragment) is list:
|
||||||
if type(part[0]) is not str or (type(part[1]) is not float and type(part[1]) is not int):
|
for x in fragment:
|
||||||
raise PromptParser.ParsingException(
|
self.append(x)
|
||||||
f"FlattenedPrompt cannot contain {part}, only Fragments or (str, float) tuples are allowed")
|
elif issubclass(type(fragment), BaseFragment):
|
||||||
parts_converted.append(Fragment(part[0], part[1]))
|
self.children.append(fragment)
|
||||||
else:
|
elif type(fragment) is tuple:
|
||||||
|
# upgrade tuples to Fragments
|
||||||
|
if type(fragment[0]) is not str or (type(fragment[1]) is not float and type(fragment[1]) is not int):
|
||||||
raise PromptParser.ParsingException(
|
raise PromptParser.ParsingException(
|
||||||
f"FlattenedPrompt cannot contain {part}, only Fragments or (str, float) tuples are allowed")
|
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
|
||||||
# all looks good
|
self.children.append(Fragment(fragment[0], fragment[1]))
|
||||||
self.children = parts_converted
|
else:
|
||||||
|
raise PromptParser.ParsingException(
|
||||||
|
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"FlattenedPrompt:{self.children}"
|
return f"FlattenedPrompt:{self.children}"
|
||||||
@ -42,9 +50,6 @@ class FlattenedPrompt():
|
|||||||
return type(other) is FlattenedPrompt and other.children == self.children
|
return type(other) is FlattenedPrompt and other.children == self.children
|
||||||
|
|
||||||
# abstract base class for Fragments
|
# abstract base class for Fragments
|
||||||
class BaseFragment:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class Fragment(BaseFragment):
|
class Fragment(BaseFragment):
|
||||||
def __init__(self, text: str, weight: float=1):
|
def __init__(self, text: str, weight: float=1):
|
||||||
assert(type(text) is str)
|
assert(type(text) is str)
|
||||||
|
@ -156,6 +156,18 @@ class PromptParserTestCase(unittest.TestCase):
|
|||||||
parse_prompt('("fire ++(flames)", "mountain 2(man)").blend(1,1)'))
|
parse_prompt('("fire ++(flames)", "mountain 2(man)").blend(1,1)'))
|
||||||
|
|
||||||
def test_cross_attention_control(self):
|
def test_cross_attention_control(self):
|
||||||
|
|
||||||
|
self.assertEqual(Conjunction([
|
||||||
|
FlattenedPrompt([Fragment('a', 1),
|
||||||
|
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
|
||||||
|
Fragment('eating a hotdog', 1)])]), parse_prompt("a \"cat\".swap(dog) eating a hotdog"))
|
||||||
|
|
||||||
|
self.assertEqual(Conjunction([
|
||||||
|
FlattenedPrompt([Fragment('a', 1),
|
||||||
|
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
|
||||||
|
Fragment('eating a hotdog', 1)])]), parse_prompt("a cat.swap(dog) eating a hotdog"))
|
||||||
|
|
||||||
|
|
||||||
fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \
|
fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \
|
||||||
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])])
|
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])])
|
||||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))
|
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))
|
||||||
|
Loading…
Reference in New Issue
Block a user