From 42883545f9d9fb308f7eb160a52028e21b6ae4b9 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Thu, 20 Oct 2022 01:42:04 +0200 Subject: [PATCH] add prompt language support for cross-attention .swap --- ldm/generate.py | 2 +- ldm/invoke/conditioning.py | 110 ++++++----- ldm/invoke/prompt_parser.py | 326 ++++++++++++++++++++++++++++++++ ldm/models/diffusion/ddpm.py | 8 +- ldm/modules/encoders/modules.py | 16 +- tests/test_prompt_parser.py | 173 +++++++++++++++++ 6 files changed, 585 insertions(+), 50 deletions(-) create mode 100644 ldm/invoke/prompt_parser.py create mode 100644 tests/test_prompt_parser.py diff --git a/ldm/generate.py b/ldm/generate.py index 45ed2e73d1..39bcc28162 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -35,7 +35,7 @@ from ldm.invoke.devices import choose_torch_device, choose_precision from ldm.invoke.conditioning import get_uc_and_c_and_ec from ldm.invoke.model_cache import ModelCache from ldm.invoke.seamless import configure_model_padding -from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale +#from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale def fix_func(orig): if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index 8c8f5eeb01..b7c8e55e66 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -11,71 +11,93 @@ log_tokenization() print out colour-coded tokens and warn if trunca ''' import re from difflib import SequenceMatcher +from typing import Union import torch -def get_uc_and_c_and_ec(prompt, model, log_tokens=False, skip_normalize=False): +from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \ + CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend +from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder + + +def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False): + # Extract Unconditioned Words From Prompt unconditioned_words = '' unconditional_regex = r'\[(.*?)\]' - unconditionals = re.findall(unconditional_regex, prompt) + unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned) if len(unconditionals) > 0: unconditioned_words = ' '.join(unconditionals) # Remove Unconditioned Words From Prompt unconditional_regex_compile = re.compile(unconditional_regex) - clean_prompt = unconditional_regex_compile.sub(' ', prompt) - prompt = re.sub(' +', ' ', clean_prompt) + clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned) + prompt_string_cleaned = re.sub(' +', ' ', clean_prompt) + else: + prompt_string_cleaned = prompt_string_uncleaned - edited_words = None - edited_regex = r'\{(.*?)\}' - edited = re.findall(edited_regex, prompt) - if len(edited) > 0: - edited_words = ' '.join(edited) - edited_regex_compile = re.compile(edited_regex) - clean_prompt = edited_regex_compile.sub(' ', prompt) - prompt = re.sub(' +', ' ', clean_prompt) + pp = PromptParser() - # get weighted sub-prompts - weighted_subprompts = split_weighted_subprompts( - prompt, skip_normalize - ) + parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned) + parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words) - ec = None + conditioning = None + edited_conditioning = None edit_opcodes = None - uc, _ = model.get_learned_conditioning([unconditioned_words]) + if parsed_prompt is Blend: + blend: Blend = parsed_prompt + embeddings_to_blend = None + for flattened_prompt in blend.prompts: + this_embedding = make_embeddings_for_flattened_prompt(model, flattened_prompt) + embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat( + (embeddings_to_blend, this_embedding)) + conditioning, _ = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0), + blend.weights, + normalize=blend.normalize_weights) + else: + flattened_prompt: FlattenedPrompt = parsed_prompt + wants_cross_attention_control = any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children]) + if wants_cross_attention_control: + original_prompt = FlattenedPrompt() + edited_prompt = FlattenedPrompt() + for fragment in flattened_prompt.children: + if type(fragment) is CrossAttentionControlSubstitute: + original_prompt.append(fragment.original_fragment) + edited_prompt.append(fragment.edited_fragment) + elif type(fragment) is CrossAttentionControlAppend: + edited_prompt.append(fragment.fragment) + else: + # regular fragment + original_prompt.append(fragment) + edited_prompt.append(fragment) + original_embeddings, original_tokens = make_embeddings_for_flattened_prompt(model, original_prompt) + edited_embeddings, edited_tokens = make_embeddings_for_flattened_prompt(model, edited_prompt) - if len(weighted_subprompts) > 1: - # i dont know if this is correct.. but it works - c = torch.zeros_like(uc) - # normalize each "sub prompt" and add it - for subprompt, weight in weighted_subprompts: - log_tokenization(subprompt, model, log_tokens, weight) - subprompt_embeddings, _ = model.get_learned_conditioning([subprompt]) - c = torch.add( - c, - subprompt_embeddings, - alpha=weight, - ) - if edited_words is not None: - print("can't do cross-attention control with blends just yet, ignoring edits") - else: # just standard 1 prompt - log_tokenization(prompt, model, log_tokens, 1) - c, c_tokens = model.get_learned_conditioning([prompt]) - if edited_words is not None: - ec, ec_tokens = model.get_learned_conditioning([edited_words]) - edit_opcodes = build_token_edit_opcodes(c_tokens, ec_tokens) + conditioning = original_embeddings + edited_conditioning = edited_embeddings + edit_opcodes = build_token_edit_opcodes(original_tokens, edited_tokens) + else: + conditioning, _ = make_embeddings_for_flattened_prompt(model, flattened_prompt) - return (uc, c, ec, edit_opcodes) + unconditioning = make_embeddings_for_flattened_prompt(parsed_negative_prompt) + return (unconditioning, conditioning, edited_conditioning, edit_opcodes) -def build_token_edit_opcodes(c_tokens, ec_tokens): - tokens = c_tokens.cpu().numpy()[0] - tokens_edit = ec_tokens.cpu().numpy()[0] - opcodes = SequenceMatcher(None, tokens, tokens_edit).get_opcodes() - return opcodes +def build_token_edit_opcodes(original_tokens, edited_tokens): + original_tokens = original_tokens.cpu().numpy()[0] + edited_tokens = edited_tokens.cpu().numpy()[0] + + return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes() + +def make_embeddings_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt): + if type(flattened_prompt) is not FlattenedPrompt: + raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead" + fragments = [x[0] for x in flattened_prompt.children] + embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True) + return embeddings, tokens + def split_weighted_subprompts(text, skip_normalize=False)->list: diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py new file mode 100644 index 0000000000..9dd0f80ade --- /dev/null +++ b/ldm/invoke/prompt_parser.py @@ -0,0 +1,326 @@ +import pyparsing +import pyparsing as pp +from pyparsing import original_text_for + + +class Prompt(): + + def __init__(self, parts: list): + for c in parts: + if not issubclass(type(c), BaseFragment): + raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {BaseFragment.__subclasses__()} are allowed") + self.children = parts + def __repr__(self): + return f"Prompt:{self.children}" + def __eq__(self, other): + return type(other) is Prompt and other.children == self.children + +class FlattenedPrompt(): + def __init__(self, parts: list): + # verify type correctness + parts_converted = [] + for part in parts: + if issubclass(type(part), BaseFragment): + parts_converted.append(part) + elif type(part) is tuple: + # upgrade tuples to Fragments + if type(part[0]) is not str or (type(part[1]) is not float and type(part[1]) is not int): + raise PromptParser.ParsingException( + f"FlattenedPrompt cannot contain {part}, only Fragments or (str, float) tuples are allowed") + parts_converted.append(Fragment(part[0], part[1])) + else: + raise PromptParser.ParsingException( + f"FlattenedPrompt cannot contain {part}, only Fragments or (str, float) tuples are allowed") + # all looks good + self.children = parts_converted + + def __repr__(self): + return f"FlattenedPrompt:{self.children}" + def __eq__(self, other): + return type(other) is FlattenedPrompt and other.children == self.children + +# abstract base class for Fragments +class BaseFragment: + pass + +class Fragment(BaseFragment): + def __init__(self, text: str, weight: float=1): + assert(type(text) is str) + self.text = text + self.weight = float(weight) + + def __repr__(self): + return "Fragment:'"+self.text+"'@"+str(self.weight) + def __eq__(self, other): + return type(other) is Fragment \ + and other.text == self.text \ + and other.weight == self.weight + +class CrossAttentionControlledFragment(BaseFragment): + pass + +class CrossAttentionControlSubstitute(CrossAttentionControlledFragment): + def __init__(self, original: Fragment, edited: Fragment): + self.original = original + self.edited = edited + + def __repr__(self): + return f"CrossAttentionControlSubstitute:('{self.original}'->'{self.edited}')" + def __eq__(self, other): + return type(other) is CrossAttentionControlSubstitute \ + and other.original == self.original \ + and other.edited == self.edited + +class CrossAttentionControlAppend(CrossAttentionControlledFragment): + def __init__(self, fragment: Fragment): + self.fragment = fragment + def __repr__(self): + return "CrossAttentionControlAppend:",self.fragment + def __eq__(self, other): + return type(other) is CrossAttentionControlAppend \ + and other.fragment == self.fragment + + + +class Conjunction(): + def __init__(self, prompts: list, weights: list = None): + # force everything to be a Prompt + #print("making conjunction with", parts) + self.prompts = [x if (type(x) is Prompt + or type(x) is Blend + or type(x) is FlattenedPrompt) + else Prompt(x) for x in prompts] + self.weights = [1.0]*len(self.prompts) if weights is None else list(weights) + if len(self.weights) != len(self.prompts): + raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}") + self.type = 'AND' + + def __repr__(self): + return f"Conjunction:{self.prompts} | weights {self.weights}" + def __eq__(self, other): + return type(other) is Conjunction \ + and other.prompts == self.prompts \ + and other.weights == self.weights + + +class Blend(): + def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True): + #print("making Blend with prompts", prompts, "and weights", weights) + if len(prompts) != len(weights): + raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}") + for c in prompts: + if type(c) is not Prompt and type(c) is not FlattenedPrompt: + raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts")) + # upcast all lists to Prompt objects + self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt) + else Prompt(x) for x in prompts] + self.prompts = prompts + self.weights = weights + self.normalize_weights = normalize_weights + + def __repr__(self): + return f"Blend:{self.prompts} | weights {self.weights}" + def __eq__(self, other): + return other.__repr__() == self.__repr__() + + +class PromptParser(): + + class ParsingException(Exception): + pass + + def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9): + + self.attention_plus_base = attention_plus_base + self.attention_minus_base = attention_minus_base + + self.root = self.build_parser_logic() + + + def parse(self, prompt: str) -> [list]: + ''' + :param prompt: The prompt string to parse + :return: a tuple + ''' + #print(f"!!parsing '{prompt}'") + + if len(prompt.strip()) == 0: + return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0]) + + root = self.root.parse_string(prompt) + #print(f"'{prompt}' parsed to root", root) + #fused = fuse_fragments(parts) + #print("fused to", fused) + + return self.flatten(root[0]) + + def flatten(self, root: Conjunction): + + def fuse_fragments(items): + # print("fusing fragments in ", items) + result = [] + for x in items: + if issubclass(type(x), CrossAttentionControlledFragment): + result.append(x) + else: + last_weight = result[-1].weight \ + if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \ + else None + this_text = x.text + this_weight = x.weight + if last_weight is not None and last_weight == this_weight: + last_text = result[-1].text + result[-1] = Fragment(last_text + ' ' + this_text, last_weight) + else: + result.append(x) + return result + + def flatten_internal(node, weight_scale, results, prefix): + #print(prefix + "flattening", node, "...") + if type(node) is pp.ParseResults: + for x in node: + results = flatten_internal(x, weight_scale, results, prefix+'pr') + #print(prefix, " ParseResults expanded, results is now", results) + elif issubclass(type(node), BaseFragment): + results.append(node) + #elif type(node) is Attention: + # #if node.weight < 1: + # # todo: inject a blend when flattening attention with weight <1" + # for c in node.children: + # results = flatten_internal(c, weight_scale*node.weight, results, prefix+' ') + elif type(node) is Blend: + flattened_subprompts = [] + #print(" flattening blend with prompts", node.prompts, "weights", node.weights) + for prompt in node.prompts: + # prompt is a list + flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ') + results += [Blend(prompts=flattened_subprompts, weights=node.weights)] + elif type(node) is Prompt: + #print(prefix + "about to flatten Prompt with children", node.children) + flattened_prompt = [] + for child in node.children: + flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ') + results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))] + #print(prefix + "after flattening Prompt, results is", results) + else: + raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}") + #print(prefix + "-> after flattening", type(node), "results is", results) + return results + + #print("flattening", root) + + flattened_parts = [] + for part in root.prompts: + flattened_parts += flatten_internal(part, 1.0, [], ' C| ') + weights = root.weights + return Conjunction(flattened_parts, weights) + + + + def build_parser_logic(self): + + lparen = pp.Literal("(").suppress() + rparen = pp.Literal(")").suppress() + # accepts int or float notation, always maps to float + number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float)) + SPACE_CHARS = ' \t\n' + + prompt_part = pp.Forward() + word = pp.Word(pp.printables).set_parse_action(lambda x: Fragment(' '.join([s for s in x]))) + word.set_name("word") + word.set_debug(False) + + def make_fragment(x): + #print("### making fragment for", x) + if type(x) is str: + return Fragment(x) + elif type(x) is pp.ParseResults or type(x) is list: + return Fragment(' '.join([s for s in x])) + else: + raise PromptParser.ParsingException("Cannot make fragment from " + str(x)) + + + original_words = ( + (lparen + pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress() + rparen).set_name('term1').set_debug(False) | + (pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('term2').set_debug(False) | + (lparen + pp.CharsNotIn(')') + rparen).set_name('term3').set_debug(False) + ).set_name('original_words') + edited_words = ( + (pp.Literal('"').suppress() + pp.CharsNotIn('"') + pp.Literal('"').suppress()).set_name('termA').set_debug(False) | + pp.CharsNotIn(')').set_name('termB').set_debug(False) + ).set_name('edited_words') + cross_attention_substitute = original_words + \ + pp.Literal(".swap").suppress() + \ + lparen + edited_words + rparen + cross_attention_substitute.set_name('cross_attention_substitute') + + def make_cross_attention_substitute(x): + #print("making cacs for", x) + return CrossAttentionControlSubstitute(x[0], x[1]) + #print("made", cacs) + #return cacs + + cross_attention_substitute.set_parse_action(make_cross_attention_substitute) + + # simple fragments of text + prompt_part << (cross_attention_substitute + #| attention + | word + ) + prompt_part.set_debug(False) + prompt_part.set_name("prompt_part") + + # root prompt definition + prompt = pp.Group(pp.OneOrMore(prompt_part))\ + .set_parse_action(lambda x: Prompt(x[0])) + + # weighted blend of prompts + # ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or + # int weights. + # can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c) + + def make_prompt_from_quoted_string(x): + #print(' got quoted prompt', x) + + x_unquoted = x[0][1:-1] + if len(x_unquoted.strip()) == 0: + # print(' b : just an empty string') + return Prompt([Fragment('')]) + # print(' b parsing ', c_unquoted) + x_parsed = prompt.parse_string(x_unquoted) + #print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed) + return x_parsed[0] + + quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string) + quoted_prompt.set_name('quoted_prompt') + + blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms') + blend_weights = pp.delimited_list(number).set_name('blend_weights') + blend = pp.Group(lparen + pp.Group(blend_terms) + rparen + + pp.Literal(".blend").suppress() + + lparen + pp.Group(blend_weights) + rparen).set_name('blend') + blend.set_debug(False) + + + blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1])) + + conjunction_terms = blend_terms.copy().set_name('conjunction_terms') + conjunction_weights = blend_weights.copy().set_name('conjunction_weights') + conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen + + pp.Literal(".and").suppress() + + lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction') + def make_conjunction(x): + parts_raw = x[0][0] + weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw) + parts = [part for part in parts_raw] + return Conjunction(parts, weights) + conjunction_with_parens_and_quotes.set_parse_action(make_conjunction) + + implicit_conjunction = pp.OneOrMore(blend | prompt) + implicit_conjunction.set_parse_action(lambda x: Conjunction(x)) + + conjunction = conjunction_with_parens_and_quotes | implicit_conjunction + conjunction.set_debug(False) + + # top-level is a conjunction of one or more blends or prompts + return conjunction diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index 4b62b5e393..57027b224c 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -820,21 +820,21 @@ class LatentDiffusion(DDPM): ) return self.scale_factor * z - def get_learned_conditioning(self, c): + def get_learned_conditioning(self, c, **kwargs): if self.cond_stage_forward is None: if hasattr(self.cond_stage_model, 'encode') and callable( self.cond_stage_model.encode ): c = self.cond_stage_model.encode( - c, embedding_manager=self.embedding_manager + c, embedding_manager=self.embedding_manager, **kwargs ) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() else: - c = self.cond_stage_model(c) + c = self.cond_stage_model(c, **kwargs) else: assert hasattr(self.cond_stage_model, self.cond_stage_forward) - c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c, **kwargs) return c def meshgrid(self, h, w): diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 12ef737134..8f4ad26119 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -1,3 +1,5 @@ +import math + import torch import torch.nn as nn from functools import partial @@ -449,11 +451,23 @@ class FrozenCLIPEmbedder(AbstractEncoder): tokens = batch_encoding['input_ids'].to(self.device) z = self.transformer(input_ids=tokens, **kwargs) - return z, tokens + if kwargs.get('return_tokens', False): + return z, tokens + else: + return z def encode(self, text, **kwargs): return self(text, **kwargs) +class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): + @classmethod + def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor: + per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device) + if normalize: + per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights) + reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,)) + return torch.sum(embeddings * reshaped_weights, dim=1) + class FrozenCLIPTextEmbedder(nn.Module): """ diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py new file mode 100644 index 0000000000..2ef56c47ae --- /dev/null +++ b/tests/test_prompt_parser.py @@ -0,0 +1,173 @@ +import unittest + +from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute + + +def parse_prompt(prompt_string): + pp = PromptParser() + #print(f"parsing '{prompt_string}'") + parse_result = pp.parse(prompt_string) + #print(f"-> parsed '{prompt_string}' to {parse_result}") + return parse_result + +class PromptParserTestCase(unittest.TestCase): + + def test_empty(self): + self.assertEqual(Conjunction([FlattenedPrompt([('', 1)])]), parse_prompt('')) + + def test_basic(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire (flames)', 1)])]), parse_prompt("fire (flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([("fire flames", 1)])]), parse_prompt("fire flames")) + self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames", 1)])]), parse_prompt("fire, flames")) + self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames , fire", 1)])]), parse_prompt("fire, flames , fire")) + + def test_attention(self): + self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.5)])]), parse_prompt("0.5(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('fire flames', 0.5)])]), parse_prompt("0.5(fire flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flames', 1.1)])]), parse_prompt("+(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.9)])]), parse_prompt("-(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1), ('flames', 0.5)])]), parse_prompt("fire 0.5(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(1.1, 2))])]), parse_prompt("++(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(0.9, 2))])]), parse_prompt("--(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))])]), parse_prompt("---(flowers) +++flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))])]), parse_prompt("---(flowers) +++flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames+', pow(1.1, 3))])]), + parse_prompt("---(flowers) +++flames+")) + self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1)])]), + parse_prompt("+(pretty flowers)")) + self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1), (', the flames are too hot', 1)])]), + parse_prompt("+(pretty flowers), the flames are too hot")) + + def test_no_parens_attention_runon(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("++fire flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("--fire flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("flowers ++fire flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("flowers --fire flames")) + + + def test_explicit_conjunction(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and(1,1)')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()')) + self.assertEqual( + Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("2.0(fire)", "-flames").and()')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]), + FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()')) + + def test_conjunction_weights(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[2.0,1.0]), parse_prompt('("fire", "flames").and(2,1)')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[1.0,2.0]), parse_prompt('("fire", "flames").and(1,2)')) + + with self.assertRaises(PromptParser.ParsingException): + parse_prompt('("fire", "flames").and(2)') + parse_prompt('("fire", "flames").and(2,1,2)') + + def test_complex_conjunction(self): + self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]), + parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle)\").and(0.5, 0.5)")) + + def test_badly_formed(self): + def make_untouched_prompt(prompt): + return Conjunction([FlattenedPrompt([(prompt, 1.0)])]) + + def assert_if_prompt_string_not_untouched(prompt): + self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt)) + + assert_if_prompt_string_not_untouched('a test prompt') + assert_if_prompt_string_not_untouched('a badly (formed test prompt') + assert_if_prompt_string_not_untouched('a badly formed test+ prompt') + assert_if_prompt_string_not_untouched('a badly (formed test+ prompt') + assert_if_prompt_string_not_untouched('a badly (formed test+ )prompt') + assert_if_prompt_string_not_untouched('a badly (formed test+ )prompt') + assert_if_prompt_string_not_untouched('(((a badly (formed test+ )prompt') + assert_if_prompt_string_not_untouched('(a (ba)dly (f)ormed test+ prompt') + self.assertEqual(Conjunction([FlattenedPrompt([('(a (ba)dly (f)ormed test+', 1.0), ('prompt', 1.1)])]), + parse_prompt('(a (ba)dly (f)ormed test+ +prompt')) + self.assertEqual(Conjunction([Blend([FlattenedPrompt([('((a badly (formed test+', 1.0)])], weights=[1.0])]), + parse_prompt('("((a badly (formed test+ ").blend(1.0)')) + + def test_blend(self): + self.assertEqual(Conjunction( + [Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]), + parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)") + ) + self.assertEqual(Conjunction([Blend( + [FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('hi', 1.0)])], + [0.7, 0.3, 1.0])]), + parse_prompt("(\"fire\", \"fire flames\", \"hi\").blend(0.7, 0.3, 1.0)") + ) + self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), + FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]), + FlattenedPrompt([('hi', 1.0)])], + weights=[0.7, 0.3, 1.0])]), + parse_prompt("(\"fire\", \"fire flames ++(hot)\", \"hi\").blend(0.7, 0.3, 1.0)") + ) + # blend a single entry is not a failure + self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]), + parse_prompt("(\"fire\").blend(0.7)") + ) + # blend with empty + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \"\").blend(0.7, 1)") + ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \" \").blend(0.7, 1)") + ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \" \").blend(0.7, 1)") + ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([(',', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \" , \").blend(0.7, 1)") + ) + + + def test_nested(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)])]), + parse_prompt('fire 2.0(flames 1.5(trees))')) + self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]), + FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])], + weights=[1.0, 1.0])]), + parse_prompt('("fire ++(flames)", "mountain 2(man)").blend(1,1)')) + + def test_cross_attention_control(self): + fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \ + CrossAttentionControlSubstitute('flames', 'trees')])]) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)')) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)')) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)')) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap("trees")')) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap("trees")')) + self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")')) + + fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \ + CrossAttentionControlSubstitute('flames', 'trees and houses')])]) + self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")')) + self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")')) + self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")')) + + trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \ + CrossAttentionControlSubstitute('trees and houses', 'flames')])]) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap(flames)')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap(flames)')) + self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)')) + + flames_to_trees_fire = Conjunction([FlattenedPrompt([ + CrossAttentionControlSubstitute('flames', 'trees'), + (', fire', 1.0)])]) + self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire')) + self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire ')) + self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire ')) + self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire')) + self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire')) + self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire')) + + +if __name__ == '__main__': + unittest.main()