diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py
index b7c8e55e66..fb6d8d443e 100644
--- a/ldm/invoke/conditioning.py
+++ b/ldm/invoke/conditioning.py
@@ -4,7 +4,7 @@ weighted subprompts.
 
 Useful function exports:
 
-get_uc_and_c()                  get the conditioned and unconditioned latent
+get_uc_and_c_and_ec()           get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
 split_weighted_subpromopts()    split subprompts, normalize and weight them
 log_tokenization()              print out colour-coded tokens and warn if truncated
 
@@ -39,8 +39,10 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
 
     pp = PromptParser()
 
-    parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned)
-    parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words)
+    # we don't support conjunctions for now
+    parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned).prompts[0]
+    parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words).prompts[0]
+    print("parsed prompt to", parsed_prompt)
 
     conditioning = None
     edited_conditioning = None
@@ -50,7 +52,7 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
         blend: Blend = parsed_prompt
         embeddings_to_blend = None
         for flattened_prompt in blend.prompts:
-            this_embedding = make_embeddings_for_flattened_prompt(model, flattened_prompt)
+            this_embedding = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
             embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
                 (embeddings_to_blend, this_embedding))
         conditioning, _ = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
@@ -72,16 +74,16 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
                 # regular fragment
                 original_prompt.append(fragment)
                 edited_prompt.append(fragment)
-        original_embeddings, original_tokens = make_embeddings_for_flattened_prompt(model, original_prompt)
-        edited_embeddings, edited_tokens = make_embeddings_for_flattened_prompt(model, edited_prompt)
+        original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt)
+        edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt)
         conditioning = original_embeddings
         edited_conditioning = edited_embeddings
         edit_opcodes = build_token_edit_opcodes(original_tokens, edited_tokens)
     else:
-        conditioning, _ = make_embeddings_for_flattened_prompt(model, flattened_prompt)
+        conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt)
 
-    unconditioning = make_embeddings_for_flattened_prompt(parsed_negative_prompt)
+    unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt)
 
     return (unconditioning, conditioning, edited_conditioning, edit_opcodes)
 
@@ -91,11 +93,11 @@ def build_token_edit_opcodes(original_tokens, edited_tokens):
 
     return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
 
-def make_embeddings_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
+def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt):
     if type(flattened_prompt) is not FlattenedPrompt:
         raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead"
-    fragments = [x[0] for x in flattened_prompt.children]
-    embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True)
+    fragments = [x.text for x in flattened_prompt.children]
+    embeddings, tokens = model.get_learned_conditioning([' '.join(fragments)], return_tokens=True)
    return embeddings, tokens
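
The edit opcodes above come straight from the standard library's `difflib`; here is a minimal, standalone sketch of what `build_token_edit_opcodes()` returns (the token ids below are made up for illustration; the real inputs are CLIP token id sequences):

```python
# Standalone sketch of build_token_edit_opcodes(); token ids are invented.
from difflib import SequenceMatcher

original_tokens = [320, 1769, 6347, 269]              # hypothetical ids for the original prompt
edited_tokens   = [320, 1769, 6347, 530, 2980, 269]   # hypothetical ids for the edited prompt

# Each opcode is (tag, i1, i2, j1, j2); 'equal' spans mark token positions
# whose attention maps can be carried over unchanged, while 'replace' /
# 'insert' / 'delete' spans mark where cross-attention control intervenes.
for tag, i1, i2, j1, j2 in SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes():
    print(tag, original_tokens[i1:i2], '->', edited_tokens[j1:j2])
```
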
diff --git a/ldm/invoke/generator/img2img.py b/ldm/invoke/generator/img2img.py
index 0a12bd90e5..6fa0d0c6dd 100644
--- a/ldm/invoke/generator/img2img.py
+++ b/ldm/invoke/generator/img2img.py
@@ -34,7 +34,7 @@ class Img2Img(Generator):
         t_enc = int(strength * steps)
 
         uc, c, ec, edit_opcodes = conditioning
-        structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
+        extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
 
         def make_image(x_T):
             # encode (scaled latent)
@@ -52,7 +52,7 @@ class Img2Img(Generator):
                 unconditional_guidance_scale=cfg_scale,
                 unconditional_conditioning=uc,
                 init_latent = self.init_latent,
-                structured_conditioning = structured_conditioning
+                extra_conditioning_info = extra_conditioning_info
                 # changes how noising is performed in ksampler
             )
 
diff --git a/ldm/invoke/generator/txt2img.py b/ldm/invoke/generator/txt2img.py
index 6e158562c5..657cccc592 100644
--- a/ldm/invoke/generator/txt2img.py
+++ b/ldm/invoke/generator/txt2img.py
@@ -22,7 +22,7 @@ class Txt2Img(Generator):
         """
         self.perlin = perlin
         uc, c, ec, edit_opcodes = conditioning
-        structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
+        extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
 
         @torch.no_grad()
         def make_image(x_T):
@@ -46,7 +46,7 @@ class Txt2Img(Generator):
                 verbose = False,
                 unconditional_guidance_scale = cfg_scale,
                 unconditional_conditioning = uc,
-                structured_conditioning = structured_conditioning,
+                extra_conditioning_info = extra_conditioning_info,
                 eta = ddim_eta,
                 img_callback = step_callback,
                 threshold = threshold,
diff --git a/ldm/invoke/generator/txt2img2img.py b/ldm/invoke/generator/txt2img2img.py
index 52a14aae74..64d0468418 100644
--- a/ldm/invoke/generator/txt2img2img.py
+++ b/ldm/invoke/generator/txt2img2img.py
@@ -24,7 +24,7 @@ class Txt2Img2Img(Generator):
         kwargs are 'width' and 'height'
         """
         uc, c, ec, edit_opcodes = conditioning
-        structured_conditioning = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
+        extra_conditioning_info = InvokeAIDiffuserComponent.StructuredConditioning(edited_conditioning=ec, edit_opcodes=edit_opcodes)
 
         @torch.no_grad()
         def make_image(x_T):
@@ -63,7 +63,7 @@ class Txt2Img2Img(Generator):
                 unconditional_conditioning = uc,
                 eta = ddim_eta,
                 img_callback = step_callback,
-                structured_conditioning = structured_conditioning
+                extra_conditioning_info = extra_conditioning_info
             )
 
             print(
@@ -97,7 +97,7 @@ class Txt2Img2Img(Generator):
                 img_callback = step_callback,
                 unconditional_guidance_scale=cfg_scale,
                 unconditional_conditioning=uc,
-                structured_conditioning = structured_conditioning
+                extra_conditioning_info = extra_conditioning_info
             )
 
             if self.free_gpu_mem:
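
All three generators now follow the same pattern: unpack the four-tuple from `get_uc_and_c_and_ec()`, wrap the last two elements, and pass the wrapper down as `extra_conditioning_info`. A sketch of the shape involved — the stand-in class below is illustrative only; the real one is `InvokeAIDiffuserComponent.StructuredConditioning`, and the flag's implementation here is an assumption based on how the samplers use it:

```python
# Illustrative stand-in for InvokeAIDiffuserComponent.StructuredConditioning.
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class ExtraConditioningInfoSketch:
    edited_conditioning: Optional[Any] = None  # edited embeddings tensor, or None
    edit_opcodes: Optional[list] = None        # difflib-style opcodes, or None

    @property
    def wants_cross_attention_control(self) -> bool:
        # assumed semantics: control is wanted whenever an edit was requested
        return self.edited_conditioning is not None

# mirrors the unpack-and-wrap step each generator performs
uc, c, ec, edit_opcodes = None, None, None, None  # from get_uc_and_c_and_ec()
info = ExtraConditioningInfoSketch(edited_conditioning=ec, edit_opcodes=edit_opcodes)
assert not info.wants_cross_attention_control    # plain prompt: no edit requested
```
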
diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py
index 9dd0f80ade..c13175a488 100644
--- a/ldm/invoke/prompt_parser.py
+++ b/ldm/invoke/prompt_parser.py
@@ -7,7 +7,7 @@ class Prompt():
 
     def __init__(self, parts: list):
         for c in parts:
-            if not issubclass(type(c), BaseFragment):
+            if type(c) is not Attention and not issubclass(type(c), BaseFragment):
                 raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {BaseFragment.__subclasses__()} are allowed")
         self.children = parts
     def __repr__(self):
@@ -56,6 +56,17 @@ class Fragment(BaseFragment):
                and other.text == self.text \
                and other.weight == self.weight
 
+class Attention():
+    def __init__(self, weight: float, children: list):
+        self.weight = weight
+        self.children = children
+        #print(f"A: requested attention '{children}' to {weight}")
+
+    def __repr__(self):
+        return f"Attention:'{self.children}' @ {self.weight}"
+    def __eq__(self, other):
+        return type(other) is Attention and other.weight == self.weight and other.children == self.children
+
 class CrossAttentionControlledFragment(BaseFragment):
     pass
 
@@ -65,7 +76,7 @@ class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
         self.edited = edited
 
     def __repr__(self):
-        return f"CrossAttentionControlSubstitute:('{self.original}'->'{self.edited}')"
+        return f"CrossAttentionControlSubstitute:({self.original}->{self.edited})"
     def __eq__(self, other):
         return type(other) is CrossAttentionControlSubstitute \
                and other.original == self.original \
@@ -137,7 +148,7 @@ class PromptParser():
 
         self.root = self.build_parser_logic()
 
-    def parse(self, prompt: str) -> [list]:
+    def parse(self, prompt: str) -> Conjunction:
         '''
         :param prompt: The prompt string to parse
         :return: a tuple
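
The loosened containment check in `Prompt.__init__` exists because `Attention` is deliberately not a `BaseFragment` subclass: it is an intermediate tree node that flattening later rewrites into weighted `Fragment`s. A toy version of the rule, with stand-in class names:

```python
# Toy model of the containment rule: Attention is allowed alongside
# BaseFragment subclasses even though it is not one itself.
class BaseFragment: pass
class Fragment(BaseFragment): pass
class Attention: pass   # intermediate node, intentionally outside the hierarchy

def validate_parts(parts: list):
    for c in parts:
        if type(c) is not Attention and not issubclass(type(c), BaseFragment):
            raise ValueError(f"Prompt cannot contain {type(c)}")

validate_parts([Fragment(), Attention()])   # accepted
```
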
@@ -181,13 +192,17 @@ class PromptParser():
                 for x in node:
                     results = flatten_internal(x, weight_scale, results, prefix+'pr')
                 #print(prefix, " ParseResults expanded, results is now", results)
-            elif issubclass(type(node), BaseFragment):
-                results.append(node)
-            #elif type(node) is Attention:
-            #    #if node.weight < 1:
-            #    #    todo: inject a blend when flattening attention with weight <1"
-            #    for c in node.children:
-            #        results = flatten_internal(c, weight_scale*node.weight, results, prefix+' ')
+            elif type(node) is Attention:
+                # if node.weight < 1:
+                #    todo: inject a blend when flattening attention with weight <1"
+                for c in node.children:
+                    results = flatten_internal(c, weight_scale * node.weight, results, prefix + ' ')
+            elif type(node) is Fragment:
+                results += [Fragment(node.text, node.weight*weight_scale)]
+            elif type(node) is CrossAttentionControlSubstitute:
+                original = flatten_internal(node.original, weight_scale, [], ' CAo ')
+                edited = flatten_internal(node.edited, weight_scale, [], ' CAe ')
+                results += [CrossAttentionControlSubstitute(original, edited)]
             elif type(node) is Blend:
                 flattened_subprompts = []
                 #print(" flattening blend with prompts", node.prompts, "weights", node.weights)
@@ -204,7 +219,7 @@ class PromptParser():
                 #print(prefix + "after flattening Prompt, results is", results)
             else:
                 raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
-            #print(prefix + "-> after flattening", type(node), "results is", results)
+            print(prefix + "-> after flattening", type(node), "results is", results)
             return results
 
         #print("flattening", root)
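
The key behaviour the new `Attention` branch adds is multiplicative weight scaling on the way down the tree: `0.5(fire 0.5(flames))` flattens to `fire`@0.5 and `flames`@0.25, as the tests at the end of this patch check. A self-contained sketch of that recursion over a toy node encoding:

```python
# Toy re-implementation of the Attention branch of flatten_internal(),
# using tuples instead of the parser's node classes.
def flatten(node, weight_scale=1.0, results=None):
    results = [] if results is None else results
    if node[0] == 'fragment':                    # ('fragment', text, weight)
        results.append((node[1], node[2] * weight_scale))
    elif node[0] == 'attention':                 # ('attention', weight, children)
        for child in node[2]:
            flatten(child, weight_scale * node[1], results)
    return results

tree = ('attention', 0.5, [('fragment', 'fire', 1.0),
                           ('attention', 0.5, [('fragment', 'flames', 1.0)])])
assert flatten(tree) == [('fire', 0.5), ('flames', 0.25)]  # matches the tests below
```
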
@@ -239,32 +254,83 @@ class PromptParser():
             else:
                 raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
 
+        # attention control of the form +(phrase) / -(phrase) / <weight>(phrase)
+        # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
+        attention = pp.Forward()
+        attention_head = (number | pp.Word('+') | pp.Word('-'))\
+            .set_name("attention_head")\
+            .set_debug(False)
+        fragment_inside_attention = pp.CharsNotIn(SPACE_CHARS+'()')\
+            .set_parse_action(make_fragment)\
+            .set_name("fragment_inside_attention")\
+            .set_debug(False)
+        attention_with_parens = pp.Forward()
+        attention_with_parens_body = pp.nested_expr(content=pp.delimited_list((attention_with_parens | fragment_inside_attention), delim=SPACE_CHARS))
+        attention_with_parens << (attention_head + attention_with_parens_body)
+
+        def make_attention(x):
+            # print("making Attention from parsing with args", x0, x1)
+            weight = 1
+            # number(str)
+            if type(x[0]) is float or type(x[0]) is int:
+                weight = float(x[0])
+            # +(str) or -(str) or +str or -str
+            elif type(x[0]) is str:
+                base = self.attention_plus_base if x[0][0] == '+' else self.attention_minus_base
+                weight = pow(base, len(x[0]))
+            # print("Making attention with children of type", [str(type(x)) for x in x1])
+            return Attention(weight=weight, children=x[1])
+
+        attention_with_parens.set_parse_action(make_attention)\
+            .set_name("attention_with_parens")\
+            .set_debug(False)
+
+        # attention control of the form ++word --word (no parens)
+        attention_without_parens = (
+                (pp.Word('+') | pp.Word('-')) +
+                pp.CharsNotIn(SPACE_CHARS+'()').set_parse_action(lambda x: [[make_fragment(x)]])
+            )\
+            .set_name("attention_without_parens")\
+            .set_debug(False)
+        attention_without_parens.set_parse_action(make_attention)
+
+        attention << (attention_with_parens | attention_without_parens)\
+            .set_name("attention")\
+            .set_debug(False)
+
+        # cross-attention control
+        empty_string = ((lparen + rparen) |
+                        pp.Literal('""').suppress() |
+                        (lparen + pp.Literal('""').suppress() + rparen)
+                        ).set_parse_action(lambda x: Fragment(""))
         original_words = (
-            (lparen + pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) |
-            (pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('term2').set_debug(False) |
-            (lparen + pp.CharsNotIn(')') + rparen).set_name('term3').set_debug(False)
+            (lparen + pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) |
+            (pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('term2').set_debug(False) |
+            (lparen + (pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)) + rparen).set_name('term3').set_debug(False)
         ).set_name('original_words')
         edited_words = (
-            (pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('termA').set_debug(False) |
-            pp.CharsNotIn(')').set_name('termB').set_debug(False)
+            (pp.Literal('"').suppress() + (pp.OneOrMore(attention) | pp.CharsNotIn('"').set_parse_action(make_fragment)) + pp.Literal('"').suppress()).set_name('termA').set_debug(False) |
+            pp.Literal('""').suppress().set_parse_action(lambda x: Fragment("")) |
+            (pp.OneOrMore(attention) | pp.CharsNotIn(')').set_parse_action(make_fragment)).set_name('termB').set_debug(True)
         ).set_name('edited_words')
-        cross_attention_substitute = original_words + \
+        cross_attention_substitute = (empty_string | original_words) + \
                                      pp.Literal(".swap").suppress() + \
-                                     lparen + edited_words + rparen
+                                     (empty_string |
+                                      (lparen + edited_words + rparen)
+                                      )
         cross_attention_substitute.set_name('cross_attention_substitute')
 
         def make_cross_attention_substitute(x):
             #print("making cacs for", x)
-            return CrossAttentionControlSubstitute(x[0], x[1])
+            cacs = CrossAttentionControlSubstitute(x[0], x[1])
             #print("made", cacs)
-            #return cacs
+            return cacs
 
         cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
 
         # simple fragments of text
         prompt_part << (cross_attention_substitute
-                        #| attention
+                        | attention
                         | word
                         )
         prompt_part.set_debug(False)
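
`make_attention()` maps the parsed head onto a weight in one of two ways: a numeric head is used directly, while a run of `+`/`-` signs is an exponent. A standalone sketch of that rule — the base constants are assumed values for illustration; the real ones live in `self.attention_plus_base` / `self.attention_minus_base`, which this diff doesn't show:

```python
# Sketch of the weight rule in make_attention(); the two base constants
# below are illustrative assumptions, not the project's actual values.
ATTENTION_PLUS_BASE = 1.1    # stand-in for self.attention_plus_base
ATTENTION_MINUS_BASE = 0.9   # stand-in for self.attention_minus_base

def attention_weight(head) -> float:
    if isinstance(head, (int, float)):           # 0.7(trees)  -> 0.7
        return float(head)
    base = ATTENTION_PLUS_BASE if head[0] == '+' else ATTENTION_MINUS_BASE
    return pow(base, len(head))                  # ++flames -> 1.1**2, --flames -> 0.9**2

print(attention_weight(0.7), attention_weight('++'), attention_weight('--'))
```
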
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index 0ab6911247..98219fb62e 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -16,10 +16,10 @@ class DDIMSampler(Sampler):
 
     def prepare_to_sample(self, t_enc, **kwargs):
         super().prepare_to_sample(t_enc, **kwargs)
 
-        structured_conditioning = kwargs.get('structured_conditioning', None)
+        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
 
-        if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
+        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
         else:
             self.invokeai_diffuser.remove_cross_attention_control()
 
diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py
index a8291e32c1..8c858757eb 100644
--- a/ldm/models/diffusion/ksampler.py
+++ b/ldm/models/diffusion/ksampler.py
@@ -34,10 +34,10 @@ class CFGDenoiser(nn.Module):
 
     def prepare_to_sample(self, t_enc, **kwargs):
 
-        structured_conditioning = kwargs.get('structured_conditioning', None)
+        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
 
-        if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
+        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
         else:
             self.invokeai_diffuser.remove_cross_attention_control()
 
@@ -164,7 +164,7 @@ class KSampler(Sampler):
         log_every_t=100,
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
-        structured_conditioning=None,
+        extra_conditioning_info=None,
         threshold = 0,
         perlin = 0,
         # this has to come in the same format as the conditioning,
         # e.g. as encoded tokens, ...
@@ -197,7 +197,7 @@ class KSampler(Sampler):
             x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
 
         model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
-        model_wrap_cfg.prepare_to_sample(S, structured_conditioning=structured_conditioning)
+        model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
         extra_args = {
             'cond': conditioning,
             'uncond': unconditional_conditioning,
@@ -224,7 +224,7 @@ class KSampler(Sampler):
         index,
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
-        structured_conditioning=None,
+        extra_conditioning_info=None,
         **kwargs,
     ):
         if self.model_wrap is None:
@@ -250,7 +250,7 @@ class KSampler(Sampler):
         # so the actual formula for indexing into sigmas:
         #    sigma_index = (steps-index)
         s_index = t_enc - index - 1
-        self.model_wrap.prepare_to_sample(s_index, structured_conditioning=structured_conditioning)
+        self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info)
         img = K.sampling.__dict__[f'_{self.schedule}'](
             self.model_wrap,
             img,
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
index 98975525ed..f58e2c3220 100644
--- a/ldm/models/diffusion/plms.py
+++ b/ldm/models/diffusion/plms.py
@@ -20,10 +20,10 @@ class PLMSSampler(Sampler):
 
     def prepare_to_sample(self, t_enc, **kwargs):
         super().prepare_to_sample(t_enc, **kwargs)
 
-        structured_conditioning = kwargs.get('structured_conditioning', None)
+        extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
 
-        if structured_conditioning is not None and structured_conditioning.wants_cross_attention_control:
-            self.invokeai_diffuser.setup_cross_attention_control(structured_conditioning)
+        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+            self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
         else:
             self.invokeai_diffuser.remove_cross_attention_control()
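
DDIM, PLMS and the k-samplers now share an identical `prepare_to_sample()` branch. One detail worth noting: the `else` arm actively removes any cross-attention control left over from a previous generation rather than doing nothing. A stub-based sketch of the shared flow, where `DiffuserStub` stands in for the real `InvokeAIDiffuserComponent`:

```python
# Stub sketch of the prepare_to_sample() branch shared by all three samplers.
class DiffuserStub:
    def setup_cross_attention_control(self, info):
        print("cross-attention control enabled")
    def remove_cross_attention_control(self):
        print("cross-attention control disabled")  # clears state left by a prior run

def prepare_to_sample_sketch(invokeai_diffuser, **kwargs):
    extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
    if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
        invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info)
    else:
        invokeai_diffuser.remove_cross_attention_control()

prepare_to_sample_sketch(DiffuserStub(), extra_conditioning_info=None)  # -> disabled
```
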
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index 8f4ad26119..18878af443 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -439,6 +439,13 @@ class FrozenCLIPEmbedder(AbstractEncoder):
             param.requires_grad = False
 
     def forward(self, text, **kwargs):
+
+        should_return_tokens = False
+        if 'return_tokens' in kwargs:
+            should_return_tokens = kwargs.get('return_tokens', False)
+            # self.transformer doesn't like having extra kwargs
+            kwargs.pop('return_tokens')
+
         batch_encoding = self.tokenizer(
             text,
             truncation=True,
@@ -451,7 +458,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         )
         tokens = batch_encoding['input_ids'].to(self.device)
         z = self.transformer(input_ids=tokens, **kwargs)
-        if kwargs.get('return_tokens', False):
+        if should_return_tokens:
             return z, tokens
         else:
             return z
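
The check-then-pop in `forward()` matters because the remaining kwargs are forwarded verbatim to `self.transformer`, which raises on unknown keyword arguments. A minimal sketch of the failure mode being avoided; the stub mimics that strictness, and `kwargs.pop('return_tokens', False)` shows the check and the pop collapsed into one step:

```python
# Why return_tokens must be stripped before forwarding kwargs: the stub below
# mimics a transformer that rejects unknown keyword arguments.
def transformer_stub(input_ids=None, **kwargs):
    if kwargs:
        raise TypeError(f"unexpected keyword arguments: {sorted(kwargs)}")
    return f"embeddings({input_ids})"

def forward_sketch(text, **kwargs):
    should_return_tokens = kwargs.pop('return_tokens', False)  # check + pop in one step
    tokens = [101, 102]                                        # placeholder tokenizer output
    z = transformer_stub(input_ids=tokens, **kwargs)
    return (z, tokens) if should_return_tokens else z

print(forward_sketch("fire flames", return_tokens=True))  # return_tokens never reaches the stub
```
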
diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py
index 2ef56c47ae..99f4db33a1 100644
--- a/tests/test_prompt_parser.py
+++ b/tests/test_prompt_parser.py
@@ -1,6 +1,7 @@
 import unittest
 
-from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute
+from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute, \
+    Fragment
 
 
 def parse_prompt(prompt_string):
@@ -135,7 +136,7 @@ class PromptParserTestCase(unittest.TestCase):
 
     def test_cross_attention_control(self):
         fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \
-                                                             CrossAttentionControlSubstitute('flames', 'trees')])])
+                                                             CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])])
         self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))
         self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)'))
         self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)'))
@@ -144,13 +145,13 @@ class PromptParserTestCase(unittest.TestCase):
         self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")'))
 
         fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \
-                                                                        CrossAttentionControlSubstitute('flames', 'trees and houses')])])
+                                                                        CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees and houses', 1)])])])
         self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")'))
         self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")'))
         self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")'))
 
         trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \
-                                                                   CrossAttentionControlSubstitute('trees and houses', 'flames')])])
+                                                                   CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])])
         self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")'))
         self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")'))
         self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")'))
@@ -159,14 +160,46 @@ class PromptParserTestCase(unittest.TestCase):
         self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)'))
 
         flames_to_trees_fire = Conjunction([FlattenedPrompt([
-                                                CrossAttentionControlSubstitute('flames', 'trees'),
+                                                CrossAttentionControlSubstitute([Fragment('flames',1)], [Fragment('trees',1)]),
                                                 (', fire', 1.0)])])
-        self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
-        self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
-        self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))
         self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire'))
         self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire'))
         self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire'))
+        self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
+        self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
+        self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))
+
+
+        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
+                                                       CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
+                         parse_prompt('a forest landscape "".swap("in winter")'))
+        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
+                                                       CrossAttentionControlSubstitute([Fragment(' ',1)], [Fragment('in winter',1)])])]),
+                         parse_prompt('a forest landscape " ".swap("in winter")'))
+
+        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
+                                                       CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
+                         parse_prompt('a forest landscape "in winter".swap("")'))
+        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
+                                                       CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
+                         parse_prompt('a forest landscape "in winter".swap()'))
+        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
+                                                       CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment(' ',1)])])]),
+                         parse_prompt('a forest landscape "in winter".swap(" ")'))
+
+    def test_cross_attention_control_with_attention(self):
+        flames_to_trees_fire = Conjunction([FlattenedPrompt([
+            CrossAttentionControlSubstitute([Fragment('flames',0.5)], [Fragment('trees',0.7)]),
+            Fragment(',', 1), Fragment('fire', 2.0)])])
+        self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(flames)".swap("0.7(trees)"), 2.0(fire)'))
+        flames_to_trees_fire = Conjunction([FlattenedPrompt([
+            CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7)]),
+            Fragment(',', 1), Fragment('fire', 2.0)])])
+        self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees)"), 2.0(fire)'))
+        flames_to_trees_fire = Conjunction([FlattenedPrompt([
+            CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7), Fragment('houses', 1)]),
+            Fragment(',', 1), Fragment('fire', 2.0)])])
+        self.assertEqual(flames_to_trees_fire, parse_prompt('"0.5(fire 0.5(flames))".swap("0.7(trees) houses"), 2.0(fire)'))
 
 
 if __name__ == '__main__':