bring in prompt parser from fix-prompts branch

attention is parsed but ignored, blends old syntax doesn't work,
	  conjunctions are parsed but ignored, the only part that's used
	  here is the new .blend() syntax and cross-attention control
	  using .swap()
This commit is contained in:
Damian at mba 2022-10-20 12:01:48 +02:00
parent 42883545f9
commit c9d27634b4
10 changed files with 169 additions and 61 deletions

View File

@ -4,7 +4,7 @@ weighted subprompts.
Useful function exports: Useful function exports:
get_uc_and_c() get the conditioned and unconditioned latent get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
split_weighted_subpromopts() split subprompts, normalize and weight them split_weighted_subpromopts() split subprompts, normalize and weight them
log_tokenization() print out colour-coded tokens and warn if truncated log_tokenization() print out colour-coded tokens and warn if truncated
@ -39,8 +39,10 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
pp = PromptParser() pp = PromptParser()
parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned) # we don't support conjunctions for now
parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words) parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned).prompts[0]
parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words).prompts[0]
print("parsed prompt to", parsed_prompt)
conditioning = None conditioning = None
edited_conditioning = None edited_conditioning = None
@ -50,7 +52,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
blend: Blend = parsed_prompt blend: Blend = parsed_prompt
embeddings_to_blend = None embeddings_to_blend = None
for flattened_prompt in blend.prompts: for flattened_prompt in blend.prompts:
this_embedding = make_embeddings_for_flattened_prompt(model, flattened_prompt) this_embedding = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat( embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
(embeddings_to_blend, this_embedding)) (embeddings_to_blend, this_embedding))
conditioning, _ = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0), conditioning, _ = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
@ -72,16 +74,16 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
# regular fragment # regular fragment
original_prompt.append(fragment) original_prompt.append(fragment)
edited_prompt.append(fragment) edited_prompt.append(fragment)
original_embeddings, original_tokens = make_embeddings_for_flattened_prompt(model, original_prompt) original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt)
edited_embeddings, edited_tokens = make_embeddings_for_flattened_prompt(model, edited_prompt) edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt)
conditioning = original_embeddings conditioning = original_embeddings
edited_conditioning = edited_embeddings edited_conditioning = edited_embeddings
edit_opcodes = build_token_edit_opcodes(original_tokens, edited_tokens) edit_opcodes = build_token_edit_opcodes(original_tokens, edited_tokens)
else: else:
conditioning, _ = make_embeddings_for_flattened_prompt(model, flattened_prompt) conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
unconditioning = make_embeddings_for_flattened_prompt(parsed_negative_prompt) unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt)
return (unconditioning, conditioning, edited_conditioning, edit_opcodes) return (unconditioning, conditioning, edited_conditioning, edit_opcodes)
@ -91,11 +93,11 @@ def build_token_edit_opcodes(original_tokens, edited_tokens):
return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes() return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
def make_embeddings_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt): def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
if type(flattened_prompt) is not FlattenedPrompt: if type(flattened_prompt) is not FlattenedPrompt:
raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead" raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead"
fragments = [x[0] for x in flattened_prompt.children] fragments = [x.text for x in flattened_prompt.children]
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True) embeddings, tokens = model.get_learned_conditioning([' '.join(fragments)], return_tokens=True)
return embeddings, tokens return embeddings, tokens

View File

@ -34,7 +34,7 @@ class Img2Img(Generator):
t_enc = int(strength * steps) t_enc = int(strength * steps)
uc, c, ec, edit_opcodes = conditioning uc, c, ec, edit_opcodes = conditioning
structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes) extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
def make_image(x_T): def make_image(x_T):
# encode (scaled latent) # encode (scaled latent)
@ -52,7 +52,7 @@ class Img2Img(Generator):
unconditional_guidance_scale=cfg_scale, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc, unconditional_conditioning=uc,
init_latent = self.init_latent, init_latent = self.init_latent,
structured_conditioning = structured_conditioning extra_conditioning_info = extra_conditioning_info
# changes how noising is performed in ksampler # changes how noising is performed in ksampler
) )

View File

@ -22,7 +22,7 @@ class Txt2Img(Generator):
""" """
self.perlin = perlin self.perlin = perlin
uc, c, ec, edit_opcodes = conditioning uc, c, ec, edit_opcodes = conditioning
structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes) extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
@torch.no_grad() @torch.no_grad()
def make_image(x_T): def make_image(x_T):
@ -46,7 +46,7 @@ class Txt2Img(Generator):
verbose = False, verbose = False,
unconditional_guidance_scale = cfg_scale, unconditional_guidance_scale = cfg_scale,
unconditional_conditioning = uc, unconditional_conditioning = uc,
structured_conditioning = structured_conditioning, extra_conditioning_info = extra_conditioning_info,
eta = ddim_eta, eta = ddim_eta,
img_callback = step_callback, img_callback = step_callback,
threshold = threshold, threshold = threshold,

View File

@ -24,7 +24,7 @@ class Txt2Img2Img(Generator):
kwargs are 'width' and 'height' kwargs are 'width' and 'height'
""" """
uc, c, ec, edit_opcodes = conditioning uc, c, ec, edit_opcodes = conditioning
structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes) extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
@torch.no_grad() @torch.no_grad()
def make_image(x_T): def make_image(x_T):
@ -63,7 +63,7 @@ class Txt2Img2Img(Generator):
unconditional_conditioning = uc, unconditional_conditioning = uc,
eta = ddim_eta, eta = ddim_eta,
img_callback = step_callback, img_callback = step_callback,
structured_conditioning = structured_conditioning extra_conditioning_info = extra_conditioning_info
) )
print( print(
@ -97,7 +97,7 @@ class Txt2Img2Img(Generator):
img_callback = step_callback, img_callback = step_callback,
unconditional_guidance_scale=cfg_scale, unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc, unconditional_conditioning=uc,
structured_conditioning = structured_conditioning extra_conditioning_info = extra_conditioning_info
) )
if self.free_gpu_mem: if self.free_gpu_mem:

View File

@ -7,7 +7,7 @@ class Prompt():
def __init__(self, parts: list): def __init__(self, parts: list):
for c in parts: for c in parts:
if not issubclass(type(c), BaseFragment): if type(c) is not Attention and not issubclass(type(c), BaseFragment):
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {BaseFragment.__subclasses__()} are allowed") raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {BaseFragment.__subclasses__()} are allowed")
self.children = parts self.children = parts
def __repr__(self): def __repr__(self):
@ -56,6 +56,17 @@ class Fragment(BaseFragment):
and other.text == self.text \ and other.text == self.text \
and other.weight == self.weight and other.weight == self.weight
class Attention():
def __init__(self, weight: float, children: list):
self.weight = weight
self.children = children
#print(f"A: requested attention '{children}' to {weight}")
def __repr__(self):
return f"Attention:'{self.children}' @ {self.weight}"
def __eq__(self, other):
return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment
class CrossAttentionControlledFragment(BaseFragment): class CrossAttentionControlledFragment(BaseFragment):
pass pass
@ -65,7 +76,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
self.edited = edited self.edited = edited
def __repr__(self): def __repr__(self):
return f"CrossAttentionControlSubstitute:('{self.original}'->'{self.edited}')" return f"CrossAttentionControlSubstitute:({self.original}->{self.edited})"
def __eq__(self, other): def __eq__(self, other):
return type(other) is CrossAttentionControlSubstitute \ return type(other) is CrossAttentionControlSubstitute \
and other.original == self.original \ and other.original == self.original \
@ -137,7 +148,7 @@ class PromptParser():
self.root = self.build_parser_logic() self.root = self.build_parser_logic()
def parse(self, prompt: str) -> [list]: def parse(self, prompt: str) -> Conjunction:
''' '''
:param prompt: The prompt string to parse :param prompt: The prompt string to parse
:return: a tuple :return: a tuple
@ -181,13 +192,17 @@ class PromptParser():
for x in node: for x in node:
results = flatten_internal(x, weight_scale, results, prefix+'pr') results = flatten_internal(x, weight_scale, results, prefix+'pr')
#print(prefix, " ParseResults expanded, results is now", results) #print(prefix, " ParseResults expanded, results is now", results)
elif issubclass(type(node), BaseFragment): elif type(node) is Attention:
results.append(node) # if node.weight < 1:
#elif type(node) is Attention: # todo: inject a blend when flattening attention with weight <1"
# #if node.weight < 1: for c in node.children:
# # todo: inject a blend when flattening attention with weight <1" results = flatten_internal(c, weight_scale * node.weight, results, prefix + ' ')
# for c in node.children: elif type(node) is Fragment:
# results = flatten_internal(c, weight_scale*node.weight, results, prefix+' ') results += [Fragment(node.text, node.weight*weight_scale)]
elif type(node) is CrossAttentionControlSubstitute:
original = flatten_internal(node.original, weight_scale, [], ' CAo ')
edited = flatten_internal(node.edited, weight_scale, [], ' CAe ')
results += [CrossAttentionControlSubstitute(original, edited)]
elif type(node) is Blend: elif type(node) is Blend:
flattened_subprompts = [] flattened_subprompts = []
#print(" flattening blend with prompts", node.prompts, "weights", node.weights) #print(" flattening blend with prompts", node.prompts, "weights", node.weights)
@ -204,7 +219,7 @@ class PromptParser():
#print(prefix + "after flattening Prompt, results is", results) #print(prefix + "after flattening Prompt, results is", results)
else: else:
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}") raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
#print(prefix + "-> after flattening", type(node), "results is", results) print(prefix + "-> after flattening", type(node), "results is", results)
return results return results
#print("flattening", root) #print("flattening", root)
@ -239,32 +254,83 @@ class PromptParser():
else: else:
raise PromptParser.ParsingException("Cannot make fragment from " + str(x)) raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
# attention control of the form +(phrase) / -(phrase) / <weight>(phrase)
# phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
attention = pp.Forward()
attention_head = (number | pp.Word('+') | pp.Word('-'))\
.set_name("attention_head")\
.set_debug(False)
fragment_inside_attention = pp.CharsNotIn(SPACE_CHARS+'()')\
.set_parse_action(make_fragment)\
.set_name("fragment_inside_attention")\
.set_debug(False)
attention_with_parens = pp.Forward()
attention_with_parens_body = pp.nested_expr(content=pp.delimited_list((attention_with_parens | fragment_inside_attention), delim=SPACE_CHARS))
attention_with_parens << (attention_head + attention_with_parens_body)
def make_attention(x):
# print("making Attention from parsing with args", x0, x1)
weight = 1
# number(str)
if type(x[0]) is float or type(x[0]) is int:
weight = float(x[0])
# +(str) or -(str) or +str or -str
elif type(x[0]) is str:
base = self.attention_plus_base if x[0][0] == '+' else self.attention_minus_base
weight = pow(base, len(x[0]))
# print("Making attention with children of type", [str(type(x)) for x in x1])
return Attention(weight=weight, children=x[1])
attention_with_parens.set_parse_action(make_attention)\
.set_name("attention_with_parens")\
.set_debug(False)
# attention control of the form ++word --word (no parens)
attention_without_parens = (
(pp.Word('+') | pp.Word('-')) +
pp.CharsNotIn(SPACE_CHARS+'()').set_parse_action(lambda x: [[make_fragment(x)]])
)\
.set_name("attention_without_parens")\
.set_debug(False)
attention_without_parens.set_parse_action(make_attention)
attention << (attention_with_parens | attention_without_parens)\
.set_name("attention")\
.set_debug(False)
# cross-attention control
empty_string = ((lparen + rparen) |
pp.Literal('""').suppress() |
(lparen + pp.Literal('""').suppress() + rparen)
).set_parse_action(lambda x: Fragment(""))
original_words = ( original_words = (
(lparen + pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) | (lparen + pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) |
(pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('term2').set_debug(False) | (pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('term2').set_debug(False) |
(lparen + pp.CharsNotIn(')') + rparen).set_name('term3').set_debug(False) (lparen + (pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)) + rparen).set_name('term3').set_debug(False)
).set_name('original_words') ).set_name('original_words')
edited_words = ( edited_words = (
(pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('termA').set_debug(False) | (pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('termA').set_debug(False) |
pp.CharsNotIn(')').set_name('termB').set_debug(False) pp.Literal('""').suppress().set_parse_action(lambda x: Fragment("")) |
(pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)).set_name('termB').set_debug(True)
).set_name('edited_words') ).set_name('edited_words')
cross_attention_substitute = original_words + \ cross_attention_substitute = (empty_string | original_words) + \
pp.Literal(".swap").suppress() + \ pp.Literal(".swap").suppress() + \
lparen + edited_words + rparen (empty_string | (lparen + edited_words + rparen)
)
cross_attention_substitute.set_name('cross_attention_substitute') cross_attention_substitute.set_name('cross_attention_substitute')
def make_cross_attention_substitute(x): def make_cross_attention_substitute(x):
#print("making cacs for", x) #print("making cacs for", x)
return CrossAttentionControlSubstitute(x[0], x[1]) cacs = CrossAttentionControlSubstitute(x[0], x[1])
#print("made", cacs) #print("made", cacs)
#return cacs return cacs
cross_attention_substitute.set_parse_action(make_cross_attention_substitute) cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
# simple fragments of text # simple fragments of text
prompt_part << (cross_attention_substitute prompt_part << (cross_attention_substitute
#| attention | attention
| word | word
) )
prompt_part.set_debug(False) prompt_part.set_debug(False)

View File

@ -16,10 +16,10 @@ class DDIMSampler(Sampler):
def prepare_to_sample(self, t_enc, **kwargs): def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs) super().prepare_to_sample(t_enc, **kwargs)
structured_conditioning = kwargs.get('structured_conditioning', None) extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control: if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning) self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
else: else:
self.invokeai_diffuser.remove_cross_attention_control() self.invokeai_diffuser.remove_cross_attention_control()

View File

@ -34,10 +34,10 @@ class CFGDenoiser(nn.Module):
def prepare_to_sample(self, t_enc, **kwargs): def prepare_to_sample(self, t_enc, **kwargs):
structured_conditioning = kwargs.get('structured_conditioning', None) extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control: if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning) self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
else: else:
self.invokeai_diffuser.remove_cross_attention_control() self.invokeai_diffuser.remove_cross_attention_control()
@ -164,7 +164,7 @@ class KSampler(Sampler):
log_every_t=100, log_every_t=100,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
structured_conditioning=None, extra_conditioning_info=None,
threshold = 0, threshold = 0,
perlin = 0, perlin = 0,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
@ -197,7 +197,7 @@ class KSampler(Sampler):
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10)) model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
model_wrap_cfg.prepare_to_sample(S, structured_conditioning=structured_conditioning) model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
extra_args = { extra_args = {
'cond': conditioning, 'cond': conditioning,
'uncond': unconditional_conditioning, 'uncond': unconditional_conditioning,
@ -224,7 +224,7 @@ class KSampler(Sampler):
index, index,
unconditional_guidance_scale=1.0, unconditional_guidance_scale=1.0,
unconditional_conditioning=None, unconditional_conditioning=None,
structured_conditioning=None, extra_conditioning_info=None,
**kwargs, **kwargs,
): ):
if self.model_wrap is None: if self.model_wrap is None:
@ -250,7 +250,7 @@ class KSampler(Sampler):
# so the actual formula for indexing into sigmas: # so the actual formula for indexing into sigmas:
# sigma_index = (steps-index) # sigma_index = (steps-index)
s_index = t_enc - index - 1 s_index = t_enc - index - 1
self.model_wrap.prepare_to_sample(s_index, structured_conditioning=structured_conditioning) self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info)
img = K.sampling.__dict__[f'_{self.schedule}']( img = K.sampling.__dict__[f'_{self.schedule}'](
self.model_wrap, self.model_wrap,
img, img,

View File

@ -20,10 +20,10 @@ class PLMSSampler(Sampler):
def prepare_to_sample(self, t_enc, **kwargs): def prepare_to_sample(self, t_enc, **kwargs):
super().prepare_to_sample(t_enc, **kwargs) super().prepare_to_sample(t_enc, **kwargs)
structured_conditioning = kwargs.get('structured_conditioning', None) extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control: if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning) self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
else: else:
self.invokeai_diffuser.remove_cross_attention_control() self.invokeai_diffuser.remove_cross_attention_control()

View File

@ -439,6 +439,13 @@ class FrozenCLIPEmbedder(AbstractEncoder):
param.requires_grad = False param.requires_grad = False
def forward(self, text, **kwargs): def forward(self, text, **kwargs):
should_return_tokens = False
if 'return_tokens' in kwargs:
should_return_tokens = kwargs.get('return_tokens', False)
# self.transformer doesn't like having extra kwargs
kwargs.pop('return_tokens')
batch_encoding = self.tokenizer( batch_encoding = self.tokenizer(
text, text,
truncation=True, truncation=True,
@ -451,7 +458,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
tokens = batch_encoding['input_ids'].to(self.device) tokens = batch_encoding['input_ids'].to(self.device)
z = self.transformer(input_ids=tokens, **kwargs) z = self.transformer(input_ids=tokens, **kwargs)
if kwargs.get('return_tokens', False): if should_return_tokens:
return z, tokens return z, tokens
else: else:
return z return z

View File

@ -1,6 +1,7 @@
import unittest import unittest
from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute, \
Fragment
def parse_prompt(prompt_string): def parse_prompt(prompt_string):
@ -135,7 +136,7 @@ class PromptParserTestCase(unittest.TestCase):
def test_cross_attention_control(self): def test_cross_attention_control(self):
fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \ fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute('flames', 'trees')])]) CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])])
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)')) self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)')) self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)'))
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)')) self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)'))
@ -144,13 +145,13 @@ class PromptParserTestCase(unittest.TestCase):
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")')) self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")'))
fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \ fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute('flames', 'trees and houses')])]) CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees and houses', 1)])])])
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")')) self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")'))
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")')) self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")'))
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")')) self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")'))
trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \ trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \
CrossAttentionControlSubstitute('trees and houses', 'flames')])]) CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])])
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")')) self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")')) self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")'))
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")')) self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")'))
@ -159,14 +160,46 @@ class PromptParserTestCase(unittest.TestCase):
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)')) self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)'))
flames_to_trees_fire = Conjunction([FlattenedPrompt([ flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute('flames', 'trees'), CrossAttentionControlSubstitute([Fragment('flames',1)], [Fragment('trees',1)]),
(', fire', 1.0)])]) (', fire', 1.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire')) self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire')) self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire')) self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
parse_prompt('a forest landscape "".swap("in winter")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment(' ',1)], [Fragment('in winter',1)])])]),
parse_prompt('a forest landscape " ".swap("in winter")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
parse_prompt('a forest landscape "in winter".swap("")'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
parse_prompt('a forest landscape "in winter".swap()'))
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment(' ',1)])])]),
parse_prompt('a forest landscape "in winter".swap(" ")'))
def test_cross_attention_control_with_attention(self):
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('flames',0.5)], [Fragment('trees',0.7)]),
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(flames)".swap("0.7(trees)"), 2.0(fire)'))
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7)]),
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees)"), 2.0(fire)'))
flames_to_trees_fire = Conjunction([FlattenedPrompt([
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7), Fragment('houses', 1)]),
Fragment(',', 1), Fragment('fire', 2.0)])])
self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)'))
if __name__ == '__main__': if __name__ == '__main__':