bring in prompt parser from fix-prompts branch

attention is parsed but ignored, blends old syntax doesn't work,
	  conjunctions are parsed but ignored, the only part that's used
	  here is the new .blend() syntax and cross-attention control
	  using .swap()
This commit is contained in:
Damian at mba 2022-10-20 12:01:48 +02:00
parent 42883545f9
commit c9d27634b4
10 changed files with 169 additions and 61 deletions

View File

@ -4,7 +4,7 @@ weighted subprompts.
Useful function exports:
get_uc_and_c() get the conditioned and unconditioned latent
get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
split_weighted_subpromopts() split subprompts, normalize and weight them
log_tokenization() print out colour-coded tokens and warn if truncated
@ -39,8 +39,10 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
pp = PromptParser()
parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned)
parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words)
# we don't support conjunctions for now
parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned).prompts[0]
parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words).prompts[0]
print("parsed prompt to", parsed_prompt)
conditioning = None
edited_conditioning = None
@ -50,7 +52,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
blend: Blend = parsed_prompt
embeddings_to_blend = None
for flattened_prompt in blend.prompts:
this_embedding = make_embeddings_for_flattened_prompt(model, flattened_prompt)
this_embedding = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
(embeddings_to_blend, this_embedding))
conditioning, _ = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
@ -72,16 +74,16 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
# regular fragment
original_prompt.append(fragment)
edited_prompt.append(fragment)
original_embeddings, original_tokens = make_embeddings_for_flattened_prompt(model, original_prompt)
edited_embeddings, edited_tokens = make_embeddings_for_flattened_prompt(model, edited_prompt)
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt)
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt)
conditioning = original_embeddings
edited_conditioning = edited_embeddings
edit_opcodes = build_token_edit_opcodes(original_tokens, edited_tokens)
else:
conditioning, _ = make_embeddings_for_flattened_prompt(model, flattened_prompt)
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
unconditioning = make_embeddings_for_flattened_prompt(parsed_negative_prompt)
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt)
return (unconditioning, conditioning, edited_conditioning, edit_opcodes)
@ -91,11 +93,11 @@ def build_token_edit_opcodes(original_tokens, edited_tokens):
return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
def make_embeddings_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
if type(flattened_prompt) is not FlattenedPrompt:
raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead"
fragments = [x[0] for x in flattened_prompt.children]
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True)
fragments = [x.text for x in flattened_prompt.children]
embeddings, tokens = model.get_learned_conditioning([' '.join(fragments)], return_tokens=True)
return embeddings, tokens

View File

@ -34,7 +34,7 @@ class Img2Img(Generator):
t_enc = int(strength * steps)
uc, c, ec, edit_opcodes = conditioning
structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
def make_image(x_T):
# encode (scaled latent)
@ -52,7 +52,7 @@ class Img2Img(Generator):
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
init_latent = self.init_latent,
structured_conditioning = structured_conditioning
extra_conditioning_info = extra_conditioning_info
# changes how noising is performed in ksampler
)

View File

@ -22,7 +22,7 @@ class Txt2Img(Generator):
"""
self.perlin = perlin
uc, c, ec, edit_opcodes = conditioning
structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
@torch.no_grad()
def make_image(x_T):
@ -46,7 +46,7 @@ class Txt2Img(Generator):
verbose = False,
unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc,
structured_conditioning = structured_conditioning,
extra_conditioning_info = extra_conditioning_info,
eta = ddim_eta,
img_callback = step_callback,
threshold = threshold,

View File

@ -24,7 +24,7 @@ class Txt2Img2Img(Generator):
kwargs are 'width' and 'height'
"""
uc, c, ec, edit_opcodes = conditioning
structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
@torch.no_grad()
def make_image(x_T):
@ -63,7 +63,7 @@ class Txt2Img2Img(Generator):
unconditional_conditioning = uc,
eta = ddim_eta,
img_callback = step_callback,
structured_conditioning = structured_conditioning
extra_conditioning_info = extra_conditioning_info
)
print(
@ -97,7 +97,7 @@ class Txt2Img2Img(Generator):
img_callback = step_callback,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
structured_conditioning = structured_conditioning
extra_conditioning_info = extra_conditioning_info
)
if self.free_gpu_mem:

View File

@ -7,7 +7,7 @@ class Prompt():
def __init__(self, parts: list):
for c in parts:
if not issubclass(type(c), BaseFragment):
if type(c) is not Attention and not issubclass(type(c), BaseFragment):
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {BaseFragment.__subclasses__()} are allowed")
self.children = parts
def __repr__(self):
@ -56,6 +56,17 @@ class Fragment(BaseFragment):
and other.text == self.text \
and other.weight == self.weight
class Attention():
def __init__(self, weight: float, children: list):
self.weight = weight
self.children = children
#print(f"A: requested attention '{children}' to {weight}")
def __repr__(self):
return f"Attention:'{self.children}' @ {self.weight}"
def __eq__(self, other):
return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment
class CrossAttentionControlledFragment(BaseFragment):
pass
@ -65,7 +76,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
self.edited = edited
def __repr__(self):
return f"CrossAttentionControlSubstitute:('{self.original}'->'{self.edited}')"
return f"CrossAttentionControlSubstitute:({self.original}->{self.edited})"
def __eq__(self, other):
return type(other) is CrossAttentionControlSubstitute \
and other.original == self.original \
@ -137,7 +148,7 @@ class PromptParser():
self.root = self.build_parser_logic()
def parse(self, prompt: str) -> [list]:
def parse(self, prompt: str) -> Conjunction:
'''
:param prompt: The prompt string to parse
:return: a tuple
@ -181,13 +192,17 @@ class PromptParser():
for x in node:
results = flatten_internal(x, weight_scale, results, prefix+'pr')
#print(prefix, " ParseResults expanded, results is now", results)
elif issubclass(type(node), BaseFragment):
results.append(node)
#elif type(node) is Attention:
# #if node.weight < 1:
# # todo: inject a blend when flattening attention with weight <1"
# for c in node.children:
# results = flatten_internal(c, weight_scale*node.weight, results, prefix+' ')
elif type(node) is Attention:
# if node.weight < 1:
# todo: inject a blend when flattening attention with weight <1"
for c in node.children:
results = flatten_internal(c, weight_scale * node.weight, results, prefix + ' ')
elif type(node) is Fragment:
results += [Fragment(node.text, node.weight*weight_scale)]
elif type(node) is CrossAttentionControlSubstitute:
original = flatten_internal(node.original, weight_scale, [], ' CAo ')
edited = flatten_internal(node.edited, weight_scale, [], ' CAe ')
results += [CrossAttentionControlSubstitute(original, edited)]
elif type(node) is Blend:
flattened_subprompts = []
#print(" flattening blend with prompts", node.prompts, "weights", node.weights)
@ -204,7 +219,7 @@ class PromptParser():
#print(prefix + "after flattening Prompt, results is", results)
else:
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
#print(prefix + "-> after flattening", type(node), "results is", results)
print(prefix + "-> after flattening", type(node), "results is", results)
return results
#print("flattening", root)
@ -239,32 +254,83 @@ class PromptParser():
else:
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
# attention control of the form +(phrase) / -(phrase) / <weight>(phrase)
# phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
attention = pp.Forward()
attention_head = (number | pp.Word('+') | pp.Word('-'))\
.set_name("attention_head")\
.set_debug(False)
fragment_inside_attention = pp.CharsNotIn(SPACE_CHARS+'()')\
.set_parse_action(make_fragment)\
.set_name("fragment_inside_attention")\
.set_debug(False)
attention_with_parens = pp.Forward()
attention_with_parens_body = pp.nested_expr(content=pp.delimited_list((attention_with_parens | fragment_inside_attention), delim=SPACE_CHARS))
attention_with_parens << (attention_head + attention_with_parens_body)
def make_attention(x):
# print("making Attention from parsing with args", x0, x1)
weight = 1
# number(str)
if type(x[0]) is float or type(x[0]) is int:
weight = float(x[0])
# +(str) or -(str) or +str or -str
elif type(x[0]) is str:
base = self.attention_plus_base if x[0][0] == '+' else self.attention_minus_base
weight = pow(base, len(x[0]))
# print("Making attention with children of type", [str(type(x)) for x in x1])
return Attention(weight=weight, children=x[1])
attention_with_parens.set_parse_action(make_attention)\
.set_name("attention_with_parens")\
.set_debug(False)
# attention control of the form ++word --word (no parens)
attention_without_parens = (
(pp.Word('+') | pp.Word('-')) +
pp.CharsNotIn(SPACE_CHARS+'()').set_parse_action(lambda x: [[make_fragment(x)]])
)\
.set_name("attention_without_parens")\
.set_debug(False)
attention_without_parens.set_parse_action(make_attention)
attention << (attention_with_parens | attention_without_parens)\
.set_name("attention")\
.set_debug(False)
# cross-attention control
empty_string = ((lparen + rparen) |
pp.Literal('""').suppress() |
(lparen + pp.Literal('""').suppress() + rparen)
).set_parse_action(lambda x: Fragment(""))
original_words = (
(lparen + pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) |
(pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('term2').set_debug(False) |
(lparen + pp.CharsNotIn(')') + rparen).set_name('term3').set_debug(False)
(lparen + pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) |
(pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('term2').set_debug(False) |
(lparen + (pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)) + rparen).set_name('term3').set_debug(False)
).set_name('original_words')
edited_words = (
(pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('termA').set_debug(False) |
pp.CharsNotIn(')').set_name('termB').set_debug(False)
(pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('termA').set_debug(False) |
pp.Literal('""').suppress().set_parse_action(lambda x: Fragment("")) |
(pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)).set_name('termB').set_debug(True)
).set_name('edited_words')
cross_attention_substitute = original_words + \
cross_attention_substitute = (empty_string | original_words) + \
pp.Literal(".swap").suppress() + \
lparen + edited_words + rparen
(empty_string | (lparen + edited_words + rparen)
)
cross_attention_substitute.set_name('cross_attention_substitute')
def make_cross_attention_substitute(x):
#print("making cacs for", x)
return CrossAttentionControlSubstitute(x[0], x[1])
cacs = CrossAttentionControlSubstitute(x[0], x[1])
#print("made", cacs)
#return cacs
return cacs
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
# simple fragments of text
prompt_part << (cross_attention_substitute
#| attention
| attention
| word
)
prompt_part.set_debug(False)

View File

@ -16,10 +16,10 @@ class DDIMSampler(Sampler):
def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)
structured_conditioning = kwargs.get('structured_conditioning', None)
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
else:
self.invokeai_diffuser.remove_cross_attention_control()

View File

@ -34,10 +34,10 @@ class CFGDenoiser(nn.Module):
def prepare_to_sample(self, t_enc, **kwargs):
structured_conditioning = kwargs.get('structured_conditioning', None)
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
else:
self.invokeai_diffuser.remove_cross_attention_control()
@ -164,7 +164,7 @@ class KSampler(Sampler):
log_every_t=100,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
structured_conditioning=None,
extra_conditioning_info=None,
threshold = 0,
perlin = 0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@ -197,7 +197,7 @@ class KSampler(Sampler):
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
model_wrap_cfg.prepare_to_sample(S, structured_conditioning=structured_conditioning)
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
extra_args = {
'cond': conditioning,
'uncond': unconditional_conditioning,
@ -224,7 +224,7 @@ class KSampler(Sampler):
index,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
structured_conditioning=None,
extra_conditioning_info=None,
**kwargs,
):
if self.model_wrap is None:
@ -250,7 +250,7 @@ class KSampler(Sampler):
# so the actual formula for indexing into sigmas:
# sigma_index = (steps-index)
s_index = t_enc - index - 1
self.model_wrap.prepare_to_sample(s_index, structured_conditioning=structured_conditioning)
self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info)
img = K.sampling.__dict__[f'_{self.schedule}'](
self.model_wrap,
img,

View File

@ -20,10 +20,10 @@ class PLMSSampler(Sampler):
def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs)
structured_conditioning = kwargs.get('structured_conditioning', None)
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
else:
self.invokeai_diffuser.remove_cross_attention_control()

View File

@ -439,6 +439,13 @@ class FrozenCLIPEmbedder(AbstractEncoder):
param.requires_grad = False
def forward(self, text, **kwargs):
should_return_tokens = False
if 'return_tokens' in kwargs:
should_return_tokens = kwargs.get('return_tokens', False)
# self.transformer doesn't like having extra kwargs
kwargs.pop('return_tokens')
batch_encoding = self.tokenizer(
text,
truncation=True,
@ -451,7 +458,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
tokens = batch_encoding['input_ids'].to(self.device)
z = self.transformer(input_ids=tokens, **kwargs)
if kwargs.get('return_tokens', False):
if should_return_tokens:
return z, tokens
else:
return z

View File

@ -1,6 +1,7 @@
import unittest
from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute
from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute, \
Fragment
def parse_prompt(prompt_string):
@ -135,7 +136,7 @@ class PromptParserTestCase(unittest.TestCase):
def test_cross_attention_control(self):
fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute('flames', 'trees')])])
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])])
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)'))
@ -144,13 +145,13 @@ class PromptParserTestCase(unittest.TestCase):
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")'))
fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute('flames', 'trees and houses')])])
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees and houses', 1)])])])
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")'))
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")'))
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")'))
trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute('trees and houses', 'flames')])])
CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])])
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")'))
@ -159,14 +160,46 @@ class PromptParserTestCase(unittest.TestCase):
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)'))
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute('flames', 'trees'),
CrossAttentionControlSubstitute([Fragment('flames',1)], [Fragment('trees',1)]),
(', fire', 1.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
parse_prompt('a forest landscape "".swap("in winter")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment(' ',1)], [Fragment('in winter',1)])])]),
parse_prompt('a forest landscape " ".swap("in winter")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
parse_prompt('a forest landscape "in winter".swap("")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
parse_prompt('a forest landscape "in winter".swap()'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment(' ',1)])])]),
parse_prompt('a forest landscape "in winter".swap(" ")'))
def test_cross_attention_control_with_attention(self):
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('flames',0.5)], [Fragment('trees',0.7)]),
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(flames)".swap("0.7(trees)"), 2.0(fire)'))
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7)]),
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees)"), 2.0(fire)'))
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7), Fragment('houses', 1)]),
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)'))
if __name__ == '__main__':