From ded3f13a331d490b99e1bc927d59d165159d1567 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Sun, 19 Feb 2023 20:42:29 +0100 Subject: [PATCH] move all prompting stuff to use compel --- invokeai/backend/invoke_ai_web_server.py | 2 +- ldm/invoke/CLI.py | 9 +- ldm/invoke/args.py | 2 +- ldm/invoke/conditioning.py | 295 +++----- ldm/invoke/generator/diffusers_pipeline.py | 12 +- ldm/invoke/prompt_parser.py | 655 ------------------ .../diffusion/cross_attention_control.py | 34 +- ldm/modules/prompt_to_embeddings_converter.py | 236 ------- ldm/modules/textual_inversion_manager.py | 5 +- pyproject.toml | 1 + tests/test_prompt_parser.py | 499 ------------- 11 files changed, 104 insertions(+), 1646 deletions(-) delete mode 100644 ldm/invoke/prompt_parser.py delete mode 100644 ldm/modules/prompt_to_embeddings_converter.py delete mode 100644 tests/test_prompt_parser.py diff --git a/invokeai/backend/invoke_ai_web_server.py b/invokeai/backend/invoke_ai_web_server.py index 8bb97a69d5..ca2761aa29 100644 --- a/invokeai/backend/invoke_ai_web_server.py +++ b/invokeai/backend/invoke_ai_web_server.py @@ -30,7 +30,7 @@ from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState from ldm.invoke.generator.inpaint import infill_methods from ldm.invoke.globals import Globals, global_converted_ckpts_dir from ldm.invoke.pngwriter import PngWriter, retrieve_metadata -from ldm.invoke.prompt_parser import split_weighted_subprompts, Blend +from compel.prompt_parser import split_weighted_subprompts, Blend from ldm.invoke.globals import global_models_dir from ldm.invoke.merge_diffusers import merge_diffusion_models diff --git a/ldm/invoke/CLI.py b/ldm/invoke/CLI.py index d639c16640..01cfa50a87 100644 --- a/ldm/invoke/CLI.py +++ b/ldm/invoke/CLI.py @@ -9,6 +9,8 @@ from typing import List, Optional, Union import click +from compel import PromptParser + if sys.platform == "darwin": os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" @@ -23,7 +25,6 @@ from ldm.invoke.image_util import make_grid from ldm.invoke.log import write_log from ldm.invoke.model_manager import ModelManager from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata -from ldm.invoke.prompt_parser import PromptParser from ldm.invoke.readline import Completer, get_completer from ldm.util import url_attachment_name @@ -749,7 +750,7 @@ def import_ckpt_model( base_name = Path(url_attachment_name(path_or_url)).name if is_a_url else Path(path_or_url).name default_name = Path(base_name).stem default_description = f"Imported model {default_name}" - + model_name, model_description = _get_model_name_and_desc( manager, completer, @@ -834,7 +835,7 @@ def _ask_for_config_file(model_path: Union[str,Path], completer, plural: bool=Fa '2': 'v2-inference-v.yaml', '3': 'v1-inpainting-inference.yaml', } - + prompt = '''What type of models are these?: [1] Models based on Stable Diffusion 1.X [2] Models based on Stable Diffusion 2.X @@ -843,7 +844,7 @@ def _ask_for_config_file(model_path: Union[str,Path], completer, plural: bool=Fa [1] A model based on Stable Diffusion 1.X [2] A model based on Stable Diffusion 2.X [3] An inpainting models based on Stable Diffusion 1.X -[4] Something else''' +[4] Something else''' print(prompt) choice = input(f'Your choice: [{default}] ') choice = choice.strip() or default diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py index 1bd1aa46ab..2e0aa5463b 100644 --- a/ldm/invoke/args.py +++ b/ldm/invoke/args.py @@ -93,9 +93,9 @@ import shlex import sys import ldm.invoke import ldm.invoke.pngwriter +from 
compel.prompt_parser import split_weighted_subprompts from ldm.invoke.globals import Globals -from ldm.invoke.prompt_parser import split_weighted_subprompts from argparse import Namespace from pathlib import Path diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index 542ef22767..778d8bc3e8 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -9,59 +9,75 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an import re from typing import Union -import torch - -from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \ - CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment -from ..models.diffusion import cross_attention_control +from compel import Compel +from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute +from .devices import torch_dtype from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent -from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder -from ..modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter from ldm.invoke.globals import Globals def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False): - # lazy-load any deferred textual inversions. # this might take a couple of seconds the first time a textual inversion is used. model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string) - prompt, negative_prompt = get_prompt_structure(prompt_string, - skip_normalize_legacy_blend=skip_normalize_legacy_blend) - conditioning = _get_conditioning_for_prompt(prompt, negative_prompt, model, log_tokens) + compel = Compel(tokenizer=model.tokenizer, + text_encoder=model.text_encoder, + textual_inversion_manager=model.textual_inversion_manager, + dtype_for_device_getter=torch_dtype) - return conditioning + positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string) + positive_prompt = compel.parse_prompt_string(positive_prompt_string) + negative_prompt = compel.parse_prompt_string(negative_prompt_string) + + if log_tokens or getattr(Globals, "log_tokenization", False): + log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer) + + c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt) + uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt) + + tokens_count = get_tokens_for_prompt(tokenizer=model.tokenizer, parsed_prompt=positive_prompt) + + ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count, + cross_attention_control_args=options.get( + 'cross_attention_control', None)) + return uc, c, ec -def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = False) -> ( -Union[FlattenedPrompt, Blend], FlattenedPrompt): +def get_prompt_structure(prompt_string, model, skip_normalize_legacy_blend: bool = False) -> ( + Union[FlattenedPrompt, Blend], FlattenedPrompt): """ parse the passed-in prompt string and return tuple (positive_prompt, negative_prompt) """ - prompt, negative_prompt = _parse_prompt_string(prompt_string, - skip_normalize_legacy_blend=skip_normalize_legacy_blend) - return prompt, negative_prompt + compel = Compel(tokenizer=model.tokenizer, + text_encoder=model.text_encoder, + textual_inversion_manager=model.textual_inversion_manager, + dtype_for_device_getter=torch_dtype) + + positive_prompt_string, negative_prompt_string = 
split_prompt_to_positive_and_negative(prompt_string) + positive_prompt = compel.parse_prompt_string(positive_prompt_string) + negative_prompt = compel.parse_prompt_string(negative_prompt_string) + + return positive_prompt, negative_prompt -def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]: +def get_tokens_for_prompt(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]: text_fragments = [x.text if type(x) is Fragment else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x)) for x in parsed_prompt.children] text = " ".join(text_fragments) - tokens = model.cond_stage_model.tokenizer.tokenize(text) + tokens = tokenizer.tokenize(text) if truncate_if_too_long: - max_tokens_length = model.cond_stage_model.max_length - 2 # typically 75 + max_tokens_length = tokenizer.model_max_length - 2 # typically 75 tokens = tokens[0:max_tokens_length] return tokens -def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=False) -> Union[FlattenedPrompt, Blend]: - # Extract Unconditioned Words From Prompt +def split_prompt_to_positive_and_negative(prompt_string_uncleaned): unconditioned_words = '' unconditional_regex = r'\[(.*?)\]' unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned) - if len(unconditionals) > 0: unconditioned_words = ' '.join(unconditionals) @@ -71,210 +87,57 @@ def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=Fa prompt_string_cleaned = re.sub(' +', ' ', clean_prompt) else: prompt_string_cleaned = prompt_string_uncleaned - - pp = PromptParser() - - parsed_prompt: Union[FlattenedPrompt, Blend] = None - legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned, skip_normalize_legacy_blend) - if legacy_blend is not None: - parsed_prompt = legacy_blend - else: - # we don't support conjunctions for now - parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0] - - parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0] - return parsed_prompt, parsed_negative_prompt + return prompt_string_cleaned, unconditioned_words -def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], parsed_negative_prompt: FlattenedPrompt, - model, log_tokens=False) \ - -> tuple[torch.Tensor, torch.Tensor, InvokeAIDiffuserComponent.ExtraConditioningInfo]: - """ - Process prompt structure and tokens, and return (conditioning, unconditioning, extra_conditioning_info) - """ +def log_tokenization(positive_prompt: Blend | FlattenedPrompt, + negative_prompt: Blend | FlattenedPrompt, + tokenizer): + print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}") + print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}") - if log_tokens or getattr(Globals, "log_tokenization", False): - print(f"\n>> [TOKENLOG] Parsed Prompt: {parsed_prompt}") - print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {parsed_negative_prompt}") + log_tokenization_for_prompt_object(positive_prompt, tokenizer) + log_tokenization_for_prompt_object(negative_prompt, tokenizer, display_label_prefix="(negative prompt)") - conditioning = None - cac_args: cross_attention_control.Arguments = None - if type(parsed_prompt) is Blend: - conditioning = _get_conditioning_for_blend(model, parsed_prompt, log_tokens) - elif type(parsed_prompt) is FlattenedPrompt: - if parsed_prompt.wants_cross_attention_control: - conditioning, cac_args = _get_conditioning_for_cross_attention_control(model, 
parsed_prompt, log_tokens) +def log_tokenization_for_prompt_object(p: Blend | FlattenedPrompt, tokenizer, display_label_prefix=None): + display_label_prefix = display_label_prefix or "" + if type(p) is Blend: + blend: Blend = p + for i, c in enumerate(blend.prompts): + log_tokenization_for_prompt_object( + c, tokenizer, + display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})") + elif type(p) is FlattenedPrompt: + flattened_prompt: FlattenedPrompt = p + if flattened_prompt.wants_cross_attention_control: + original_fragments = [] + edited_fragments = [] + for f in flattened_prompt.children: + if type(f) is CrossAttentionControlSubstitute: + original_fragments += f.original + edited_fragments += f.edited + else: + original_fragments.append(f) + edited_fragments.append(f) + original_text = " ".join([x.text for x in original_fragments]) + log_tokenization_for_text(original_text, tokenizer, + display_label=f"{display_label_prefix}(.swap originals)") + edited_text = " ".join([x.text for x in edited_fragments]) + log_tokenization_for_text(edited_text, tokenizer, + display_label=f"{display_label_prefix}(.swap replacements)") else: - conditioning, _ = _get_embeddings_and_tokens_for_prompt(model, - parsed_prompt, - log_tokens=log_tokens, - log_display_label="(prompt)") - else: - raise ValueError(f"parsed_prompt is '{type(parsed_prompt)}' which is not a supported prompt type") - - unconditioning, _ = _get_embeddings_and_tokens_for_prompt(model, - parsed_negative_prompt, - log_tokens=log_tokens, - log_display_label="(unconditioning)") - if isinstance(conditioning, dict): - # hybrid conditioning is in play - unconditioning, conditioning = _flatten_hybrid_conditioning(unconditioning, conditioning) - if cac_args is not None: - print( - ">> Hybrid conditioning cannot currently be combined with cross attention control. 
Cross attention control will be ignored.") - cac_args = None - - if type(parsed_prompt) is Blend: - blend: Blend = parsed_prompt - all_token_sequences = [get_tokens_for_prompt(model, p) for p in blend.prompts] - longest_token_sequence = max(all_token_sequences, key=lambda t: len(t)) - eos_token_index = len(longest_token_sequence)+1 - else: - tokens = get_tokens_for_prompt(model, parsed_prompt) - eos_token_index = len(tokens)+1 - return ( - unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo( - tokens_count_including_eos_bos=eos_token_index + 1, - cross_attention_control_args=cac_args - ) - ) + text = " ".join([x.text for x in flattened_prompt.children]) + log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix) -def _get_conditioning_for_cross_attention_control(model, prompt: FlattenedPrompt, log_tokens: bool = True): - original_prompt = FlattenedPrompt() - edited_prompt = FlattenedPrompt() - # for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed - original_token_count = 0 - edited_token_count = 0 - edit_options = [] - edit_opcodes = [] - # beginning of sequence - edit_opcodes.append( - ('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1)) - edit_options.append(None) - original_token_count += 1 - edited_token_count += 1 - for fragment in prompt.children: - if type(fragment) is CrossAttentionControlSubstitute: - original_prompt.append(fragment.original) - edited_prompt.append(fragment.edited) - - to_replace_token_count = _get_tokens_length(model, fragment.original) - replacement_token_count = _get_tokens_length(model, fragment.edited) - edit_opcodes.append(('replace', - original_token_count, original_token_count + to_replace_token_count, - edited_token_count, edited_token_count + replacement_token_count - )) - original_token_count += to_replace_token_count - edited_token_count += replacement_token_count - edit_options.append(fragment.options) - # elif type(fragment) is CrossAttentionControlAppend: - # edited_prompt.append(fragment.fragment) - else: - # regular fragment - original_prompt.append(fragment) - edited_prompt.append(fragment) - - count = _get_tokens_length(model, [fragment]) - edit_opcodes.append(('equal', original_token_count, original_token_count + count, edited_token_count, - edited_token_count + count)) - edit_options.append(None) - original_token_count += count - edited_token_count += count - # end of sequence - edit_opcodes.append( - ('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1)) - edit_options.append(None) - original_token_count += 1 - edited_token_count += 1 - original_embeddings, original_tokens = _get_embeddings_and_tokens_for_prompt(model, - original_prompt, - log_tokens=log_tokens, - log_display_label="(.swap originals)") - # naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of - # subsequent tokens when there is >1 edit and earlier edits change the total token count. - # eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the - # 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra - # token 'smiling' in the inactive 'cat' edit. 
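    # Editor's illustrative aside (not from the original file; token counts are
    # assumed): for "a cat.swap(dog) eating a hotdog", the loop above would build
    # opcodes conceptually like
    #   [('equal',   0, 1, 0, 1),   # bos
    #    ('equal',   1, 2, 1, 2),   # "a"
    #    ('replace', 2, 3, 2, 3),   # "cat" -> "dog"
    #    ('equal',   3, 6, 3, 6),   # "eating a hotdog" (assuming 3 tokens)
    #    ('equal',   6, 7, 6, 7)]   # eos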
- # todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions - edited_embeddings, edited_tokens = _get_embeddings_and_tokens_for_prompt(model, - edited_prompt, - log_tokens=log_tokens, - log_display_label="(.swap replacements)") - conditioning = original_embeddings - edited_conditioning = edited_embeddings - # print('>> got edit_opcodes', edit_opcodes, 'options', edit_options) - cac_args = cross_attention_control.Arguments( - edited_conditioning=edited_conditioning, - edit_opcodes=edit_opcodes, - edit_options=edit_options - ) - return conditioning, cac_args - - -def _get_conditioning_for_blend(model, blend: Blend, log_tokens: bool = False): - embeddings_to_blend = None - for i, flattened_prompt in enumerate(blend.prompts): - this_embedding, _ = _get_embeddings_and_tokens_for_prompt(model, - flattened_prompt, - log_tokens=log_tokens, - log_display_label=f"(blend part {i + 1}, weight={blend.weights[i]})") - embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat( - (embeddings_to_blend, this_embedding)) - conditioning = WeightedPromptFragmentsToEmbeddingsConverter.apply_embedding_weights(embeddings_to_blend.unsqueeze(0), - blend.weights, - normalize=blend.normalize_weights) - return conditioning - - -def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool = False, - log_display_label: str = None): - if type(flattened_prompt) is not FlattenedPrompt: - raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead") - fragments = [x.text for x in flattened_prompt.children] - weights = [x.weight for x in flattened_prompt.children] - embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights]) - if log_tokens or getattr(Globals, "log_tokenization", False): - text = " ".join(fragments) - log_tokenization(text, model, display_label=log_display_label) - - return embeddings, tokens - - -def _get_tokens_length(model, fragments: list[Fragment]): - fragment_texts = [x.text for x in fragments] - tokens = model.cond_stage_model.get_token_ids(fragment_texts, include_start_and_end_markers=False) - return sum([len(x) for x in tokens]) - - -def _flatten_hybrid_conditioning(uncond, cond): - ''' - This handles the choice between a conditional conditioning - that is a tensor (used by cross attention) vs one that has additional - dimensions as well, as used by 'hybrid' - ''' - assert isinstance(uncond, dict) - assert isinstance(cond, dict) - cond_flattened = dict() - for k in cond: - if isinstance(cond[k], list): - cond_flattened[k] = [ - torch.cat([uncond[k][i], cond[k][i]]) - for i in range(len(cond[k])) - ] - else: - cond_flattened[k] = torch.cat([uncond[k], cond[k]]) - return uncond, cond_flattened - - -def log_tokenization(text, model, display_label=None): +def log_tokenization_for_text(text, tokenizer, display_label=None): """ shows how the prompt is tokenized # usually tokens have '' to indicate end-of-word, # but for readability it has been replaced with ' ' """ - tokens = model.cond_stage_model.tokenizer.tokenize(text) + tokens = tokenizer.tokenize(text) tokenized = "" discarded = "" usedTokens = 0 @@ -284,7 +147,7 @@ def log_tokenization(text, model, display_label=None): token = tokens[i].replace('', ' ') # alternate color s = (usedTokens % 6) + 1 - if i < model.cond_stage_model.max_length: + if i < tokenizer.model_max_length: tokenized = tokenized + 
f"\x1b[0;3{s};40m{token}" usedTokens += 1 else: # over max token length @@ -293,7 +156,7 @@ def log_tokenization(text, model, display_label=None): if usedTokens > 0: print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):') print(f'{tokenized}\x1b[0m') - + if discarded != "": print(f'\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):') print(f'{discarded}\x1b[0m') diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 6755bbb880..7072cc4c54 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -31,7 +31,7 @@ from ldm.modules.textual_inversion_manager import TextualInversionManager from ..devices import normalize_device, CPU_DEVICE from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver -from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter +from compel import EmbeddingsProvider @dataclass @@ -294,7 +294,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): text_encoder=self.text_encoder, full_precision=use_full_precision) # InvokeAI's interface for text embeddings and whatnot - self.prompt_fragments_to_embeddings_converter = WeightedPromptFragmentsToEmbeddingsConverter( + self.embeddings_provider = EmbeddingsProvider( tokenizer=self.tokenizer, text_encoder=self.text_encoder, textual_inversion_manager=self.textual_inversion_manager @@ -726,15 +726,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline): """ Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion. """ - return self.prompt_fragments_to_embeddings_converter.get_embeddings_for_weighted_prompt_fragments( - text=c, - fragment_weights=fragment_weights, + return self.embeddings_provider.get_embeddings_for_weighted_prompt_fragments( + text_batch=c, + fragment_weights_batch=fragment_weights, should_return_tokens=return_tokens, device=self._model_group.device_for(self.unet)) @property def cond_stage_model(self): - return self.prompt_fragments_to_embeddings_converter + return self.embeddings_provider @torch.inference_mode() def _tokenize(self, prompt: Union[str, List[str]]): diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py deleted file mode 100644 index 248243106e..0000000000 --- a/ldm/invoke/prompt_parser.py +++ /dev/null @@ -1,655 +0,0 @@ -import string -from typing import Union, Optional -import re -import pyparsing as pp -''' -This module parses prompt strings and produces tree-like structures that can be used generate and control the conditioning tensors. -weighted subprompts. - -Useful class exports: - -PromptParser - parses prompts - -Useful function exports: - -split_weighted_subpromopts() split subprompts, normalize and weight them -log_tokenization() print out colour-coded tokens and warn if truncated -''' - -class Prompt(): - """ - Mid-level structure for storing the tree-like result of parsing a prompt. A Prompt may not represent the whole of - the singular user-defined "prompt string" (although it can) - for example, if the user specifies a Blend, the objects - that are to be blended together are stored individuall as Prompt objects. - - Nesting makes this object not suitable for directly tokenizing; instead call flatten() on the containing Conjunction - to produce a FlattenedPrompt. 
- """ - def __init__(self, parts: list): - for c in parts: - if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults: - raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} ({c}), only {[c.__name__ for c in BaseFragment.__subclasses__()]} are allowed") - self.children = parts - def __repr__(self): - return f"Prompt:{self.children}" - def __eq__(self, other): - return type(other) is Prompt and other.children == self.children - -class BaseFragment: - pass - -class FlattenedPrompt(): - """ - A Prompt that has been passed through flatten(). Its children can be readily tokenized. - """ - def __init__(self, parts: list=[]): - self.children = [] - for part in parts: - self.append(part) - - def append(self, fragment: Union[list, BaseFragment, tuple]): - # verify type correctness - if type(fragment) is list: - for x in fragment: - self.append(x) - elif issubclass(type(fragment), BaseFragment): - self.children.append(fragment) - elif type(fragment) is tuple: - # upgrade tuples to Fragments - if type(fragment[0]) is not str or (type(fragment[1]) is not float and type(fragment[1]) is not int): - raise PromptParser.ParsingException( - f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed") - self.children.append(Fragment(fragment[0], fragment[1])) - else: - raise PromptParser.ParsingException( - f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed") - - @property - def is_empty(self): - return len(self.children) == 0 or \ - (len(self.children) == 1 and len(self.children[0].text) == 0) - - @property - def wants_cross_attention_control(self): - return any( - [issubclass(type(x), CrossAttentionControlledFragment) for x in self.children] - ) - - def __repr__(self): - return f"FlattenedPrompt:{self.children}" - def __eq__(self, other): - return type(other) is FlattenedPrompt and other.children == self.children - - -class Fragment(BaseFragment): - """ - A Fragment is a chunk of plain text and an optional weight. The text should be passed as-is to the CLIP tokenizer. - """ - def __init__(self, text: str, weight: float=1): - assert(type(text) is str) - if '\\"' in text or '\\(' in text or '\\)' in text: - #print("Fragment converting escaped \( \) \\\" into ( ) \"") - text = text.replace('\\(', '(').replace('\\)', ')').replace('\\"', '"') - self.text = text - self.weight = float(weight) - - def __repr__(self): - return "Fragment:'"+self.text+"'@"+str(self.weight) - def __eq__(self, other): - return type(other) is Fragment \ - and other.text == self.text \ - and other.weight == self.weight - -class Attention(): - """ - Nestable weight control for fragments. Each object in the children array may in turn be an Attention object; - weights should be considered to accumulate as the tree is traversed to deeper levels of nesting. - - Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object. 
- """ - def __init__(self, weight: float, children: list): - if type(weight) is not float: - raise PromptParser.ParsingException( - f"Attention weight must be float (got {type(weight).__name__} {weight})") - self.weight = weight - if type(children) is not list: - raise PromptParser.ParsingException(f"cannot make Attention with non-list of children (got {type(children)})") - assert(type(children) is list) - self.children = children - #print(f"A: requested attention '{children}' to {weight}") - - def __repr__(self): - return f"Attention:{self.children} * {self.weight}" - def __eq__(self, other): - return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment - -class CrossAttentionControlledFragment(BaseFragment): - pass - -class CrossAttentionControlSubstitute(CrossAttentionControlledFragment): - """ - A Cross-Attention Controlled ('prompt2prompt') fragment, for use inside a Prompt, Attention, or FlattenedPrompt. - Representing an "original" word sequence that supplies feature vectors for an initial diffusion operation, and an - "edited" word sequence, to which the attention maps produced by the "original" word sequence are applied. Intuitively, - the result should be an "edited" image that looks like the "original" image with concepts swapped. - - eg "a cat sitting on a car" (original) -> "a smiling dog sitting on a car" (edited): the edited image should look - almost exactly the same as the original, but with a smiling dog rendered in place of the cat. The - CrossAttentionControlSubstitute object representing this swap may be confined to the tokens being swapped: - CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]) - or it may represent a larger portion of the token sequence: - CrossAttentionControlSubstitute(original=[Fragment('a cat sitting on a car')], - edited=[Fragment('a smiling dog sitting on a car')]) - - In either case expect it to be embedded in a Prompt or FlattenedPrompt: - FlattenedPrompt([ - Fragment('a'), - CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]), - Fragment('sitting on a car') - ]) - """ - def __init__(self, original: list, edited: list, options: dict=None): - self.original = original if len(original)>0 else [Fragment('')] - self.edited = edited if len(edited)>0 else [Fragment('')] - - default_options = { - 's_start': 0.0, - 's_end': 0.2062994740159002, # ~= shape_freedom=0.5 - 't_start': 0.1, - 't_end': 1.0 - } - merged_options = default_options - if options is not None: - shape_freedom = options.pop('shape_freedom', None) - if shape_freedom is not None: - # high shape freedom = SD can do what it wants with the shape of the object - # high shape freedom => s_end = 0 - # low shape freedom => s_end = 1 - # shape freedom is in a "linear" space, while noticeable changes to s_end are typically closer around 0, - # and there is very little perceptible difference as s_end increases above 0.5 - # so for shape_freedom = 0.5 we probably want s_end to be 0.2 - # -> cube root and subtract from 1.0 - merged_options['s_end'] = 1.0 - shape_freedom ** (1. / 3.) 
- #print('converted shape_freedom argument to', merged_options) - merged_options.update(options) - - self.options = merged_options - - def __repr__(self): - return f"CrossAttentionControlSubstitute:({self.original}->{self.edited} ({self.options})" - def __eq__(self, other): - return type(other) is CrossAttentionControlSubstitute \ - and other.original == self.original \ - and other.edited == self.edited \ - and other.options == self.options - - -class CrossAttentionControlAppend(CrossAttentionControlledFragment): - def __init__(self, fragment: Fragment): - self.fragment = fragment - def __repr__(self): - return "CrossAttentionControlAppend:",self.fragment - def __eq__(self, other): - return type(other) is CrossAttentionControlAppend \ - and other.fragment == self.fragment - - - -class Conjunction(): - """ - Storage for one or more Prompts or Blends, each of which is to be separately diffused and then the results merged - by weighted sum in latent space. - """ - def __init__(self, prompts: list, weights: list = None): - # force everything to be a Prompt - #print("making conjunction with", prompts, "types", [type(p).__name__ for p in prompts]) - self.prompts = [x if (type(x) is Prompt - or type(x) is Blend - or type(x) is FlattenedPrompt) - else Prompt(x) for x in prompts] - self.weights = [1.0]*len(self.prompts) if (weights is None or len(weights)==0) else list(weights) - if len(self.weights) != len(self.prompts): - raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}") - self.type = 'AND' - - def __repr__(self): - return f"Conjunction:{self.prompts} | weights {self.weights}" - def __eq__(self, other): - return type(other) is Conjunction \ - and other.prompts == self.prompts \ - and other.weights == self.weights - - -class Blend(): - """ - Stores a Blend of multiple Prompts. To apply, build feature vectors for each of the child Prompts and then perform a - weighted blend of the feature vectors to produce a single feature vector that is effectively a lerp between the - Prompts. 
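    Example (editor's illustration): ("a tall man", "a short man").blend(1, 0.5) parses to a
    Blend with weights [1.0, 0.5]; with normalize_weights=True (the default) the weights are
    rescaled to sum to 1 before the feature vectors are summed. Appending the no_normalize
    flag, e.g. .blend(1, 0.5, no_normalize), skips that rescaling.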
- """ - def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True): - #print("making Blend with prompts", prompts, "and weights", weights) - weights = [1.0]*len(prompts) if (weights is None or len(weights)==0) else list(weights) - if len(prompts) != len(weights): - raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}") - for p in prompts: - if type(p) is not Prompt and type(p) is not FlattenedPrompt: - raise(PromptParser.ParsingException(f"{type(p)} cannot be added to a Blend, only Prompts or FlattenedPrompts")) - for f in p.children: - if isinstance(f, CrossAttentionControlSubstitute): - raise(PromptParser.ParsingException(f"while parsing Blend: sorry, you cannot do .swap() as part of a Blend")) - - # upcast all lists to Prompt objects - self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt) - else Prompt(x) - for x in prompts] - self.prompts = prompts - self.weights = weights - self.normalize_weights = normalize_weights - - @property - def wants_cross_attention_control(self): - # blends cannot cross-attention control - return False - - - def __repr__(self): - return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}" - def __eq__(self, other): - return other.__repr__() == self.__repr__() - - -class PromptParser(): - - class ParsingException(Exception): - pass - - class UnrecognizedOperatorException(ParsingException): - def __init__(self, operator:str): - super().__init__("Unrecognized operator: " + operator) - - def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9): - - self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base) - - - def parse_conjunction(self, prompt: str) -> Conjunction: - ''' - :param prompt: The prompt string to parse - :return: a Conjunction representing the parsed results. - ''' - #print(f"!!parsing '{prompt}'") - - if len(prompt.strip()) == 0: - return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0]) - - root = self.conjunction.parse_string(prompt) - #print(f"'{prompt}' parsed to root", root) - #fused = fuse_fragments(parts) - #print("fused to", fused) - - return self.flatten(root[0]) - - def parse_legacy_blend(self, text: str, skip_normalize: bool = False) -> Optional[Blend]: - weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize) - if len(weighted_subprompts) <= 1: - return None - strings = [x[0] for x in weighted_subprompts] - weights = [x[1] for x in weighted_subprompts] - - parsed_conjunctions = [self.parse_conjunction(x) for x in strings] - flattened_prompts = [x.prompts[0] for x in parsed_conjunctions] - - return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize) - - - def flatten(self, root: Conjunction, verbose = False) -> Conjunction: - """ - Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends, - producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects - that can be readily tokenized without the need to walk a complex tree structure. - - :param root: The Conjunction to flatten. - :return: A Conjunction containing the result of flattening each of the prompts in the passed-in root. 
- """ - - def fuse_fragments(items): - # print("fusing fragments in ", items) - result = [] - for x in items: - if type(x) is CrossAttentionControlSubstitute: - original_fused = fuse_fragments(x.original) - edited_fused = fuse_fragments(x.edited) - result.append(CrossAttentionControlSubstitute(original_fused, edited_fused, options=x.options)) - else: - last_weight = result[-1].weight \ - if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \ - else None - this_text = x.text - this_weight = x.weight - if last_weight is not None and last_weight == this_weight: - last_text = result[-1].text - result[-1] = Fragment(last_text + ' ' + this_text, last_weight) - else: - result.append(x) - return result - - def flatten_internal(node, weight_scale, results, prefix): - verbose and print(prefix + "flattening", node, "...") - if type(node) is pp.ParseResults or type(node) is list: - for x in node: - results = flatten_internal(x, weight_scale, results, prefix+' pr ') - #print(prefix, " ParseResults expanded, results is now", results) - elif type(node) is Attention: - # if node.weight < 1: - # todo: inject a blend when flattening attention with weight <1" - for index,c in enumerate(node.children): - results = flatten_internal(c, weight_scale * node.weight, results, prefix + f" att{index} ") - elif type(node) is Fragment: - results += [Fragment(node.text, node.weight*weight_scale)] - elif type(node) is CrossAttentionControlSubstitute: - original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ') - edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ') - results += [CrossAttentionControlSubstitute(original, edited, options=node.options)] - elif type(node) is Blend: - flattened_subprompts = [] - #print(" flattening blend with prompts", node.prompts, "weights", node.weights) - for prompt in node.prompts: - # prompt is a list - flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ') - results += [Blend(prompts=flattened_subprompts, weights=node.weights, normalize_weights=node.normalize_weights)] - elif type(node) is Prompt: - #print(prefix + "about to flatten Prompt with children", node.children) - flattened_prompt = [] - for child in node.children: - flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ') - results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))] - #print(prefix + "after flattening Prompt, results is", results) - else: - raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}") - verbose and print(prefix + "-> after flattening", type(node).__name__, "results is", results) - return results - - verbose and print("flattening", root) - - flattened_parts = [] - for part in root.prompts: - flattened_parts += flatten_internal(part, 1.0, [], ' C| ') - - verbose and print("flattened to", flattened_parts) - - weights = root.weights - return Conjunction(flattened_parts, weights) - - - - -def build_parser_syntax(attention_plus_base: float, attention_minus_base: float): - def make_operator_object(x): - #print('making operator for', x) - target = x[0] - operator = x[1] - arguments = x[2] - if operator == '.attend': - weight_raw = arguments[0] - weight = 1.0 - if type(weight_raw) is float or type(weight_raw) is int: - weight = weight_raw - elif type(weight_raw) is str: - base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base - weight = pow(base, len(weight_raw)) - return 
Attention(weight=weight, children=[x for x in x[0]]) - elif operator == '.swap': - return CrossAttentionControlSubstitute(target, arguments, x.as_dict()) - elif operator == '.blend': - prompts = [Prompt(p) for p in x[0]] - weights_raw = x[2] - normalize_weights = True - if len(weights_raw) > 0 and weights_raw[-1][0] == 'no_normalize': - normalize_weights = False - weights_raw = weights_raw[:-1] - weights = [float(w[0]) for w in weights_raw] - return Blend(prompts=prompts, weights=weights, normalize_weights=normalize_weights) - elif operator == '.and' or operator == '.add': - prompts = [Prompt(p) for p in x[0]] - weights = [float(w[0]) for w in x[2]] - return Conjunction(prompts=prompts, weights=weights) - - raise PromptParser.UnrecognizedOperatorException(operator) - - def parse_fragment_str(x, expression: pp.ParseExpression, in_quotes: bool = False, in_parens: bool = False): - #print(f"parsing fragment string for {x}") - fragment_string = x[0] - if len(fragment_string.strip()) == 0: - return Fragment('') - - if in_quotes: - # escape unescaped quotes - fragment_string = fragment_string.replace('"', '\\"') - - try: - result = (expression + pp.StringEnd()).parse_string(fragment_string) - #print("parsed to", result) - return result - except pp.ParseException as e: - #print("parse_fragment_str couldn't parse prompt string:", e) - raise - - # meaningful symbols - lparen = pp.Literal("(").suppress() - rparen = pp.Literal(")").suppress() - quote = pp.Literal('"').suppress() - comma = pp.Literal(",").suppress() - dot = pp.Literal(".").suppress() - equals = pp.Literal("=").suppress() - - escaped_lparen = pp.Literal('\\(') - escaped_rparen = pp.Literal('\\)') - escaped_quote = pp.Literal('\\"') - escaped_comma = pp.Literal('\\,') - escaped_dot = pp.Literal('\\.') - escaped_plus = pp.Literal('\\+') - escaped_minus = pp.Literal('\\-') - escaped_equals = pp.Literal('\\=') - - syntactic_symbols = { - '(': escaped_lparen, - ')': escaped_rparen, - '"': escaped_quote, - ',': escaped_comma, - '.': escaped_dot, - '+': escaped_plus, - '-': escaped_minus, - '=': escaped_equals, - } - syntactic_chars = "".join(syntactic_symbols.keys()) - - # accepts int or float notation, always maps to float - number = pp.pyparsing_common.real | \ - pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float)) - - # for options - keyword = pp.Word(pp.alphanums + '_') - - # a word that absolutely does not contain any meaningful syntax - non_syntax_word = pp.Combine(pp.OneOrMore(pp.MatchFirst([ - pp.Or(syntactic_symbols.values()), - pp.one_of(['-', '+']) + pp.NotAny(pp.White() | pp.Char(syntactic_chars) | pp.StringEnd()), - # build character-by-character - pp.CharsNotIn(string.whitespace + syntactic_chars, exact=1) - ]))) - non_syntax_word.set_parse_action(lambda x: [Fragment(t) for t in x]) - non_syntax_word.set_name('non_syntax_word') - non_syntax_word.set_debug(False) - - # a word that can contain any character at all - greedily consumes syntax, so use with care - free_word = pp.CharsNotIn(string.whitespace).set_parse_action(lambda x: Fragment(x[0])) - free_word.set_name('free_word') - free_word.set_debug(False) - - - # ok here we go. forward declare some things.. 
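    # Editor's illustrative aside (standalone sketch, not part of the original
    # grammar): pp.Forward() lets rules refer to each other before they are
    # defined, which is what makes nested syntax like "(a (cat)1.1)1.1" parseable.
    _demo_expr = pp.Forward()
    _demo_atom = pp.Word(pp.alphas) | (pp.Suppress('(') + _demo_expr + pp.Suppress(')'))
    _demo_expr << pp.OneOrMore(_demo_atom)
    # _demo_expr.parse_string('a (b c) d')  ->  ['a', 'b', 'c', 'd']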
- attention = pp.Forward() - cross_attention_substitute = pp.Forward() - parenthesized_fragment = pp.Forward() - quoted_fragment = pp.Forward() - - # the types of things that can go into a fragment, consisting of syntax-full and/or strictly syntax-free components - fragment_part_expressions = [ - attention, - cross_attention_substitute, - parenthesized_fragment, - quoted_fragment, - non_syntax_word - ] - # a fragment that is permitted to contain commas - fragment_including_commas = pp.ZeroOrMore(pp.MatchFirst( - fragment_part_expressions + [ - pp.Literal(',').set_parse_action(lambda x: Fragment(x[0])) - ] - )) - # a fragment that is not permitted to contain commas - fragment_excluding_commas = pp.ZeroOrMore(pp.MatchFirst( - fragment_part_expressions - )) - - # a fragment in double quotes (may be nested) - quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"') - quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, fragment_including_commas, in_quotes=True)) - - # a fragment inside parentheses (may be nested) - parenthesized_fragment << (lparen + fragment_including_commas + rparen) - parenthesized_fragment.set_name('parenthesized_fragment') - parenthesized_fragment.set_debug(False) - - # a string of the form (= | | ) where keyword is alphanumeric + '_' - option = pp.Group(pp.MatchFirst([ - keyword + equals + (number | keyword), # option=value - number.copy().set_parse_action(pp.token_map(str)), # weight - keyword # flag - ])) - # options for an operator, eg "s_start=0.1, 0.3, no_normalize" - options = pp.Dict(pp.Optional(pp.delimited_list(option))) - options.set_name('options') - options.set_debug(False) - - # a fragment which can be used as the target for an operator - either quoted or in parentheses, or a bare vanilla word - potential_operator_target = (quoted_fragment | parenthesized_fragment | non_syntax_word) - - # a fragment whose weight has been increased or decreased by a given amount - attention_weight_operator = pp.Word('+') | pp.Word('-') | number - attention_explicit = ( - pp.Group(potential_operator_target) - + pp.Literal('.attend') - + lparen - + pp.Group(attention_weight_operator) - + rparen - ) - attention_explicit.set_parse_action(make_operator_object) - attention_implicit = ( - pp.Group(potential_operator_target) - + pp.NotAny(pp.White()) # do not permit whitespace between term and operator - + pp.Group(attention_weight_operator) - ) - attention_implicit.set_parse_action(lambda x: make_operator_object([x[0], '.attend', x[1]])) - attention << (attention_explicit | attention_implicit) - attention.set_name('attention') - attention.set_debug(False) - - # cross-attention control by swapping one fragment for another - cross_attention_substitute << ( - pp.Group(potential_operator_target).set_name('ca-target').set_debug(False) - + pp.Literal(".swap").set_name('ca-operator').set_debug(False) - + lparen - + pp.Group(fragment_excluding_commas).set_name('ca-replacement').set_debug(False) - + pp.Optional(comma + options).set_name('ca-options').set_debug(False) - + rparen - ) - cross_attention_substitute.set_name('cross_attention_substitute') - cross_attention_substitute.set_debug(False) - cross_attention_substitute.set_parse_action(make_operator_object) - - - # an entire self-contained prompt, which can be used in a Blend or Conjunction - prompt = pp.ZeroOrMore(pp.MatchFirst([ - cross_attention_substitute, - attention, - quoted_fragment, - parenthesized_fragment, - free_word, - pp.White().suppress() - ])) - quoted_prompt = 
quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, prompt, in_quotes=True)) - - - # a blend/lerp between the feature vectors for two or more prompts - blend = ( - lparen - + pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('bl-target').set_debug(False) - + rparen - + pp.Literal(".blend").set_name('bl-operator').set_debug(False) - + lparen - + pp.Group(options).set_name('bl-options').set_debug(False) - + rparen - ) - blend.set_name('blend') - blend.set_debug(False) - blend.set_parse_action(make_operator_object) - - # an operator to direct stable diffusion to step multiple times, once for each target, and then add the results together with different weights - explicit_conjunction = ( - lparen - + pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('cj-target').set_debug(False) - + rparen - + pp.one_of([".and", ".add"]).set_name('cj-operator').set_debug(False) - + lparen - + pp.Group(options).set_name('cj-options').set_debug(False) - + rparen - ) - explicit_conjunction.set_name('explicit_conjunction') - explicit_conjunction.set_debug(False) - explicit_conjunction.set_parse_action(make_operator_object) - - # by default a prompt consists of a Conjunction with a single term - implicit_conjunction = (blend | pp.Group(prompt)) + pp.StringEnd() - implicit_conjunction.set_parse_action(lambda x: Conjunction(x)) - - conjunction = (explicit_conjunction | implicit_conjunction) - - return conjunction, prompt - - -def split_weighted_subprompts(text, skip_normalize=False)->list: - """ - Legacy blend parsing. - - grabs all text up to the first occurrence of ':' - uses the grabbed text as a sub-prompt, and takes the value following ':' as weight - if ':' has no value defined, defaults to 1.0 - repeats until no text remaining - """ - prompt_parser = re.compile(""" - (?P # capture group for 'prompt' - (?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:' - ) # end 'prompt' - (?: # non-capture group - :+ # match one or more ':' characters - (?P # capture group for 'weight' - -?\d+(?:\.\d+)? # match positive or negative integer or decimal number - )? # end weight capture group, make optional - \s* # strip spaces after weight - | # OR - $ # else, if no ':' then match end of line - ) # end non-capture group - """, re.VERBOSE) - parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float( - match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)] - if skip_normalize: - return parsed_prompts - weight_sum = sum(map(lambda x: x[1], parsed_prompts)) - if weight_sum == 0: - print( - "* Warning: Subprompt weights add up to zero. 
Discarding and using even weights instead.") - equal_weight = 1 / max(len(parsed_prompts), 1) - return [(x[0], equal_weight) for x in parsed_prompts] - return [(x[0], x[1] / weight_sum) for x in parsed_prompts] - diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 37f0ebfa1d..25f7dc51c6 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -1,3 +1,8 @@ + +# adapted from bloc97's CrossAttentionControl colab +# https://github.com/bloc97/CrossAttentionControl + + import enum import math from typing import Optional, Callable @@ -6,35 +11,13 @@ import psutil import torch import diffusers from torch import nn + +from compel.cross_attention_control import Arguments from diffusers.models.unet_2d_condition import UNet2DConditionModel from diffusers.models.cross_attention import AttnProcessor from ldm.invoke.devices import torch_dtype -# adapted from bloc97's CrossAttentionControl colab -# https://github.com/bloc97/CrossAttentionControl - -class Arguments: - def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict): - """ - :param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768] - :param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required) - :param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes. - """ - # todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector - self.edited_conditioning = edited_conditioning - self.edit_opcodes = edit_opcodes - - if edited_conditioning is not None: - assert len(edit_opcodes) == len(edit_options), \ - "there must be 1 edit_options dict for each edit_opcodes tuple" - non_none_edit_options = [x for x in edit_options if x is not None] - assert len(non_none_edit_options)>0, "missing edit_options" - if len(non_none_edit_options)>1: - print('warning: cross-attention control options are not working properly for >1 edit') - self.edit_options = non_none_edit_options[0] - - class CrossAttentionType(enum.Enum): SELF = 1 TOKENS = 2 @@ -319,7 +302,6 @@ def override_cross_attention(model, context: Context, is_running_diffusers = Fal Inject attention parameters and functions into the passed in model to enable cross attention editing. :param model: The unet model to inject into. 
- :param cross_attention_control_args: Arugments passeed to the CrossAttentionControl implementations :return: None """ @@ -523,7 +505,7 @@ from dataclasses import field, dataclass import torch -from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor, AttnProcessor +from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor @dataclass diff --git a/ldm/modules/prompt_to_embeddings_converter.py b/ldm/modules/prompt_to_embeddings_converter.py deleted file mode 100644 index 84d927d48b..0000000000 --- a/ldm/modules/prompt_to_embeddings_converter.py +++ /dev/null @@ -1,236 +0,0 @@ -import math - -import torch -from transformers import CLIPTokenizer, CLIPTextModel - -from ldm.invoke.devices import torch_dtype -from ldm.modules.textual_inversion_manager import TextualInversionManager - - -class WeightedPromptFragmentsToEmbeddingsConverter(): - - def __init__(self, - tokenizer: CLIPTokenizer, # converts strings to lists of int token ids - text_encoder: CLIPTextModel, # convert a list of int token ids to a tensor of embeddings - textual_inversion_manager: TextualInversionManager = None - ): - self.tokenizer = tokenizer - self.text_encoder = text_encoder - self.textual_inversion_manager = textual_inversion_manager - - @property - def max_length(self): - return self.tokenizer.model_max_length - - def get_embeddings_for_weighted_prompt_fragments(self, - text: list[list[str]], - fragment_weights: list[list[float]], - should_return_tokens: bool = False, - device='cpu' - ) -> torch.Tensor: - ''' - - :param text: A list of fragments of text to which different weights are to be applied. - :param fragment_weights: A batch of lists of weights, one for each entry in `fragments`. - :return: A tensor of shape `[1, 77, token_dim]` containing weighted embeddings where token_dim is 768 for SD1 - and 1280 for SD2 - ''' - if len(text) != len(fragment_weights): - raise ValueError(f"lengths of text and fragment_weights lists are not the same ({len(text)} != {len(fragment_weights)})") - - batch_z = None - batch_tokens = None - for fragments, weights in zip(text, fragment_weights): - - # First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively - # applying a multiplier to the CFG scale on a per-token basis). - # For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept - # captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active - # interest, however small, in redness; what the user probably intends when they attach the number 0.01 to - # "red" is to tell SD that it should almost completely *ignore* redness). - # To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt - # string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the - # closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment. - - # handle weights >=1 - tokens, per_token_weights = self.get_token_ids_and_expand_weights(fragments, weights, device=device) - base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights) - - # this is our starting point - embeddings = base_embedding.unsqueeze(0) - per_embedding_weights = [1.0] - - # now handle weights <1 - # Do this by building extra embeddings tensors that lack the words being <1 weighted. 
These will be lerped - # with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting - # embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words - # removed. - # eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding - # for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it - # such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain". - for index, fragment_weight in enumerate(weights): - if fragment_weight < 1: - fragments_without_this = fragments[:index] + fragments[index+1:] - weights_without_this = weights[:index] + weights[index+1:] - tokens, per_token_weights = self.get_token_ids_and_expand_weights(fragments_without_this, weights_without_this, device=device) - embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights) - - embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1) - # weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0 - # if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding - # therefore: - # fragment_weight = 1: we are at base_z => lerp weight 0 - # fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1 - # fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf - # so let's use tan(), because: - # tan is 0.0 at 0, - # 1.0 at PI/4, and - # inf at PI/2 - # -> tan((1-weight)*PI/2) should give us ideal lerp weights - epsilon = 1e-9 - fragment_weight = max(epsilon, fragment_weight) # inf is bad - embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2) - # todo handle negative weight? - - per_embedding_weights.append(embedding_lerp_weight) - - lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0) - - #print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}") - - # append to batch - batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1) - batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1) - - # should have shape (B, 77, 768) - #print(f"assembled all tokens into tensor of shape {batch_z.shape}") - - if should_return_tokens: - return batch_z, batch_tokens - else: - return batch_z - - def get_token_ids(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]: - """ - Convert a list of strings like `["a cat", "sitting", "on a mat"]` into a list of lists of token ids like - `[[bos, 0, 1, eos], [bos, 2, eos], [bos, 3, 0, 4, eos]]`. bos/eos markers are skipped if - `include_start_and_end_markers` is `False`. Each list will be restricted to the maximum permitted length - (typically 75 tokens + eos/bos markers). - - :param fragments: The strings to convert. 
- :param include_start_and_end_markers: - :return: - """ - # for args documentation see ENCODE_KWARGS_DOCSTRING in tokenization_utils_base.py (in `transformers` lib) - token_ids_list = self.tokenizer( - fragments, - truncation=True, - max_length=self.max_length, - return_overflowing_tokens=False, - padding='do_not_pad', - return_tensors=None, # just give me lists of ints - )['input_ids'] - - result = [] - for token_ids in token_ids_list: - # trim eos/bos - token_ids = token_ids[1:-1] - # pad for textual inversions with vector length >1 - token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids_if_necessary(token_ids) - # restrict length to max_length-2 (leaving room for bos/eos) - token_ids = token_ids[0:self.max_length - 2] - # add back eos/bos if requested - if include_start_and_end_markers: - token_ids = [self.tokenizer.bos_token_id] + token_ids + [self.tokenizer.eos_token_id] - - result.append(token_ids) - - return result - - - @classmethod - def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor: - per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device) - if normalize: - per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights) - reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,)) - #reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape) - return torch.sum(embeddings * reshaped_weights, dim=1) - # lerped embeddings has shape (77, 768) - - - def get_token_ids_and_expand_weights(self, fragments: list[str], weights: list[float], device: str) -> (torch.Tensor, torch.Tensor): - ''' - Given a list of text fragments and corresponding weights: tokenize each fragment, append the token sequences - together and return a padded token sequence starting with the bos marker, ending with the eos marker, and padded - or truncated as appropriate to `self.max_length`. Also return a list of weights expanded from the passed-in - weights to match each token. - - :param fragments: Text fragments to tokenize and concatenate. May be empty. - :param weights: Per-fragment weights (i.e. quasi-CFG scaling). Values from 0 to inf are permitted. In practise with SD1.5 - values >1.6 tend to produce garbage output. Must have same length as `fragment`. - :return: A tuple of tensors `(token_ids, weights)`. `token_ids` is ints, `weights` is floats, both have shape `[self.max_length]`. 
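        Example (editor's illustration; per-word token counts assumed): fragments=['a cat', 'sitting']
        with weights=[1.0, 0.5] yields token_ids=[bos, t_a, t_cat, t_sitting, eos, pad, ...] and
        weights=[1.0, 1.0, 1.0, 0.5, 1.0, 1.0, ...], both padded out to self.max_length entries.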
-        '''
-        if len(fragments) != len(weights):
-            raise ValueError(f"lengths of fragments and weights lists are not the same ({len(fragments)} != {len(weights)})")
-
-        # empty is meaningful
-        if len(fragments) == 0:
-            fragments = ['']
-            weights = [1.0]
-        per_fragment_token_ids = self.get_token_ids(fragments, include_start_and_end_markers=False)
-        all_token_ids = []
-        per_token_weights = []
-        #print("all fragments:", fragments, weights)
-        for this_fragment_token_ids, weight in zip(per_fragment_token_ids, weights):
-            # append
-            all_token_ids += this_fragment_token_ids
-            # fill out weights tensor with one float per token
-            per_token_weights += [float(weight)] * len(this_fragment_token_ids)
-
-        # leave room for bos/eos
-        max_token_count_without_bos_eos_markers = self.max_length - 2
-        if len(all_token_ids) > max_token_count_without_bos_eos_markers:
-            excess_token_count = len(all_token_ids) - max_token_count_without_bos_eos_markers
-            # TODO build nice description string of how the truncation was applied
-            # this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to
-            # self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit.
-            print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
-            all_token_ids = all_token_ids[0:max_token_count_without_bos_eos_markers]
-            per_token_weights = per_token_weights[0:max_token_count_without_bos_eos_markers]
-
-        # pad out to a self.max_length-entry array: [bos_token, <prompt tokens>, eos_token, pad_token…]
-        # (typically self.max_length == 77)
-        all_token_ids = [self.tokenizer.bos_token_id] + all_token_ids + [self.tokenizer.eos_token_id]
-        per_token_weights = [1.0] + per_token_weights + [1.0]
-        pad_length = self.max_length - len(all_token_ids)
-        all_token_ids += [self.tokenizer.pad_token_id] * pad_length
-        per_token_weights += [1.0] * pad_length
-
-        all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long, device=device)
-        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch_dtype(self.text_encoder.device), device=device)
-        #print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
-        return all_token_ids_tensor, per_token_weights_tensor
-
-    def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor) -> torch.Tensor:
-        '''
-        Build a tensor that embeds the passed-in token IDs and applies the given per-token weights.
-        :param token_ids: A tensor of shape `[self.max_length]` containing token IDs (ints)
-        :param per_token_weights: A tensor of shape `[self.max_length]` containing weights (floats)
-        :return: A tensor of shape `[1, self.max_length, token_dim]` representing the requested weighted embeddings
-          where `token_dim` is 768 for SD1 and 1280 for SD2.
- ''' - #print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}") - if token_ids.shape != torch.Size([self.max_length]): - raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]") - - z = self.text_encoder(token_ids.unsqueeze(0), return_dict=False)[0] - empty_token_ids = torch.tensor([self.tokenizer.bos_token_id] + - [self.tokenizer.pad_token_id] * (self.max_length-2) + - [self.tokenizer.eos_token_id], dtype=torch.int, device=z.device).unsqueeze(0) - empty_z = self.text_encoder(empty_token_ids).last_hidden_state - batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape).to(z) - z_delta_from_empty = z - empty_z - weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded) - - return weighted_z diff --git a/ldm/modules/textual_inversion_manager.py b/ldm/modules/textual_inversion_manager.py index b154a95f27..f9f59585f3 100644 --- a/ldm/modules/textual_inversion_manager.py +++ b/ldm/modules/textual_inversion_manager.py @@ -8,6 +8,7 @@ import torch from picklescan.scanner import scan_file_path from transformers import CLIPTextModel, CLIPTokenizer +from compel.embeddings_provider import BaseTextualInversionManager from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary @@ -23,7 +24,7 @@ class TextualInversion: return self.embedding.shape[0] -class TextualInversionManager: +class TextualInversionManager(BaseTextualInversionManager): def __init__( self, tokenizer: CLIPTokenizer, @@ -105,7 +106,7 @@ class TextualInversionManager: def _add_textual_inversion( self, trigger_str, embedding, defer_injecting_tokens=False - ) -> TextualInversion: + ) -> Optional[TextualInversion]: """ Add a textual inversion to be recognised. :param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added. 
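
The prompt_to_embeddings_converter.py deleted above downweighted a fragment by lerping the full prompt embedding toward an embedding built with that fragment removed, using tan((1 - weight) * pi / 2) as the lerp weight: weight 1.0 stays at the base embedding, weight 0.5 lands exactly halfway, and weights near 0 are dominated by the fragment-free embedding. A minimal sketch of that scheme, kept here for reference; the function and argument names (lerp_toward_removed_fragment, base_z, z_without_fragment) are illustrative only, not part of any retained API:

    import math

    import torch

    def lerp_toward_removed_fragment(base_z: torch.Tensor,
                                     z_without_fragment: torch.Tensor,
                                     fragment_weight: float) -> torch.Tensor:
        # fragment_weight = 1.0 -> lerp weight 0   (stay at base_z)
        # fragment_weight = 0.5 -> lerp weight 1   (halfway between the two)
        # fragment_weight -> 0  -> lerp weight inf (z_without_fragment dominates)
        epsilon = 1e-9
        w = max(epsilon, fragment_weight)  # avoid tan(pi/2) = inf
        lerp_weight = math.tan((1.0 - w) * math.pi / 2)
        # a weighted sum with normalized weights is exactly the lerp described above
        total = 1.0 + lerp_weight
        return base_z * (1.0 / total) + z_without_fragment * (lerp_weight / total)
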
diff --git a/pyproject.toml b/pyproject.toml index f3dfa69b91..2c7c979ada 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "albumentations", "click", "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", + "compel", "datasets", "diffusers[torch]~=0.11", "dnspython==2.2.1", diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py deleted file mode 100644 index 0c9bbf91f9..0000000000 --- a/tests/test_prompt_parser.py +++ /dev/null @@ -1,499 +0,0 @@ -import unittest - -import pyparsing - -from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute, \ - Fragment - - -def parse_prompt(prompt_string): - pp = PromptParser() - #print(f"parsing '{prompt_string}'") - parse_result = pp.parse_conjunction(prompt_string) - #print(f"-> parsed '{prompt_string}' to {parse_result}") - return parse_result - -def make_basic_conjunction(strings: list[str]): - fragments = [Fragment(x) for x in strings] - return Conjunction([FlattenedPrompt(fragments)]) - -def make_weighted_conjunction(weighted_strings: list[tuple[str,float]]): - fragments = [Fragment(x, w) for x,w in weighted_strings] - return Conjunction([FlattenedPrompt(fragments)]) - - -class PromptParserTestCase(unittest.TestCase): - - def test_empty(self): - self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt('')) - - def test_basic(self): - self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames")) - self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)")) - self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames")) - self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire")) - self.assertEqual(make_weighted_conjunction([("cat hot-dog eating", 1)]), parse_prompt("cat hot-dog eating")) - self.assertEqual(make_basic_conjunction(['Dalí']), parse_prompt("Dalí")) - - def test_attention(self): - self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5")) - self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames).attend(0.5)")) - self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("flames.attend(0.5)")) - self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("\"flames\".attend(0.5)")) - self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5")) - self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames).attend(0.5)")) - - self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+")) - self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+")) - self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+")) - self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames.attend(+)")) - self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames).attend(+)")) - self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\".attend(+)")) - self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-")) - self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-")) - self.assertEqual(make_weighted_conjunction([('flames', 
0.9)]), parse_prompt("\"flames\"-")) - self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5")) - self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire flames.attend(0.5)")) - self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames).attend(0.5)")) - self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire \"flames\".attend(0.5)")) - self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++")) - self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--")) - self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++")) - self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1)]), - parse_prompt("(pretty flowers)+")) - self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1), (', the flames are too hot', 1)]), - parse_prompt("(pretty flowers)+, the flames are too hot")) - - def test_no_parens_attention_runon(self): - self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(1.1, 2))]), parse_prompt("fire flames++")) - self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(0.9, 2))]), parse_prompt("fire flames--")) - self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("flowers fire++ flames")) - self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("flowers fire-- flames")) - - - def test_explicit_conjunction(self): - self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and(1,1)')) - self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()')) - self.assertEqual( - Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()')) - self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("(fire)2.0", "flames-").and()')) - self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]), - FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()')) - - def test_conjunction_weights(self): - self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[2.0,1.0]), parse_prompt('("fire", "flames").and(2,1)')) - self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[1.0,2.0]), parse_prompt('("fire", "flames").and(1,2)')) - - with self.assertRaises(PromptParser.ParsingException): - parse_prompt('("fire", "flames").and(2)') - parse_prompt('("fire", "flames").and(2,1,2)') - - def test_complex_conjunction(self): - - #print(parse_prompt("a person with a hat (riding a bicycle.swap(skateboard))++")) - - self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]), - parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle)++\").and(0.5, 0.5)")) - 
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), - FlattenedPrompt([("a person with a hat", 1.0), - ("riding a", 1.1*1.1), - CrossAttentionControlSubstitute( - [Fragment("bicycle", pow(1.1,2))], - [Fragment("skateboard", pow(1.1,2))]) - ]) - ], weights=[0.5, 0.5]), - parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)")) - - def test_badly_formed(self): - def make_untouched_prompt(prompt): - return Conjunction([FlattenedPrompt([(prompt, 1.0)])]) - - def assert_if_prompt_string_not_untouched(prompt): - self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt)) - - assert_if_prompt_string_not_untouched('a test prompt') - assert_if_prompt_string_not_untouched('a badly formed +test prompt') - assert_if_prompt_string_not_untouched('a badly (formed test prompt') - - #with self.assertRaises(pyparsing.ParseException): - assert_if_prompt_string_not_untouched('a badly (formed +test prompt') - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a badly formed +test prompt',1)])]) , parse_prompt('a badly (formed +test )prompt')) - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(((a badly formed +test prompt',1)])]) , parse_prompt('(((a badly (formed +test )prompt')) - - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(a ba dly f ormed +test prompt',1)])]) , parse_prompt('(a (ba)dly (f)ormed +test prompt')) - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(a ba dly f ormed +test +prompt',1)])]) , parse_prompt('(a (ba)dly (f)ormed +test +prompt')) - self.assertEqual(Conjunction([Blend([FlattenedPrompt([Fragment('((a badly (formed +test', 1)])], [1.0])]), - parse_prompt('("((a badly (formed +test ").blend(1.0)')) - - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]), - parse_prompt("hamburger ((bun))")) - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]), - parse_prompt("hamburger (bun)")) - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger kaiser roll', 1)])]), - parse_prompt("hamburger (kaiser roll)")) - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger kaiser roll', 1)])]), - parse_prompt("hamburger ((kaiser roll))")) - - - def test_blend(self): - self.assertEqual(Conjunction( - [Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]), - parse_prompt("(\"mountain\", \"man\").blend()") - ) - self.assertEqual(Conjunction( - [Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]), - parse_prompt("(mountain, man).blend()") - ) - self.assertEqual(Conjunction( - [Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]), - parse_prompt("((mountain), (man)).blend()") - ) - self.assertEqual(Conjunction( - [Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('tall man', 1.0)])], [1.0, 1.0])]), - parse_prompt("((mountain), (tall man)).blend()") - ) - - with self.assertRaises(PromptParser.ParsingException): - print(parse_prompt("((mountain), \"cat.swap(dog)\").blend()")) - - self.assertEqual(Conjunction( - [Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]), - parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)") - ) - self.assertEqual(Conjunction([Blend( - [FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('hi', 1.0)])], - [0.7, 0.3, 1.0])]), - parse_prompt("(\"fire\", \"fire 
flames\", \"hi\").blend(0.7, 0.3, 1.0)") - ) - self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), - FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]), - FlattenedPrompt([('hi', 1.0)])], - weights=[0.7, 0.3, 1.0])]), - parse_prompt("(\"fire\", \"fire flames (hot)++\", \"hi\").blend(0.7, 0.3, 1.0)") - ) - # blend a single entry is not a failure - self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]), - parse_prompt("(\"fire\").blend(0.7)") - ) - # blend with empty - self.assertEqual( - Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), - parse_prompt("(\"fire\", \"\").blend(0.7, 1)") - ) - self.assertEqual( - Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), - parse_prompt("(\"fire\", \" \").blend(0.7, 1)") - ) - self.assertEqual( - Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), - parse_prompt("(\"fire\", \" \").blend(0.7, 1)") - ) - self.assertEqual( - Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([(',', 1.0)])], [0.7, 1.0])]), - parse_prompt("(\"fire\", \" , \").blend(0.7, 1)") - ) - - self.assertEqual( - Conjunction([Blend([FlattenedPrompt([('mountain , man , hairy', 1)]), - FlattenedPrompt([('face , teeth ,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0], normalize_weights=True)]), - parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)') - ) - self.assertEqual( - Conjunction([Blend([FlattenedPrompt([('mountain , man , hairy', 1)]), - FlattenedPrompt([('face , teeth ,', 1), ('eyes', 0.9 * 0.9)])], weights=[1.0, -1.0], normalize_weights=False)]), - parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1,no_normalize)') - ) - - with self.assertRaises(PromptParser.ParsingException): - parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3, 0.1)") - with self.assertRaises(PromptParser.ParsingException): - parse_prompt("(\"fire\", \"fire flames\").blend(0.7)") - - - def test_nested(self): - self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)]), - parse_prompt('fire (flames (trees)1.5)2.0')) - self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]), - FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])], - weights=[1.0, 1.0])]), - parse_prompt('("fire (flames)++", "mountain (man)2").blend(1,1)')) - - def test_cross_attention_control(self): - - self.assertEqual(Conjunction([FlattenedPrompt([CrossAttentionControlSubstitute([Fragment('sun')], [Fragment('moon')])])]), - parse_prompt("sun.swap(moon)")) - - self.assertEqual(Conjunction([ - FlattenedPrompt([Fragment('a', 1), - CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]), - Fragment('eating a hotdog', 1)])]), parse_prompt("a \"cat\".swap(dog) eating a hotdog")) - - self.assertEqual(Conjunction([ - FlattenedPrompt([Fragment('a', 1), - CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]), - Fragment('eating a hotdog', 1)])]), parse_prompt("a cat.swap(dog) eating a hotdog")) - - - fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \ - CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])]) - self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)')) - self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)')) - self.assertEqual(fire_flames_to_trees, 
parse_prompt('fire ("flames").swap(trees)')) - self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap("trees")')) - self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap("trees")')) - self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")')) - - fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \ - CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees and houses', 1)])])]) - self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")')) - self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")')) - self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")')) - - trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \ - CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])]) - self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")')) - self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")')) - self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")')) - self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap(flames)')) - self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap(flames)')) - self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)')) - - flames_to_trees_fire = Conjunction([FlattenedPrompt([ - CrossAttentionControlSubstitute([Fragment('flames',1)], [Fragment('trees',1)]), - (', fire', 1.0)])]) - self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire')) - self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire')) - self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire')) - self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire')) - self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire ')) - self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire ')) - - - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), - CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]), - parse_prompt('a forest landscape "".swap("in winter")')) - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), - CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]), - parse_prompt('a forest landscape ().swap(in winter)')) - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), - CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]), - parse_prompt('a forest landscape " ".swap("in winter")')) - - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), - CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]), - parse_prompt('a forest landscape "in winter".swap("")')) - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), - CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]), - parse_prompt('a forest landscape "in winter".swap()')) - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1), - 
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]), - parse_prompt('a forest landscape "in winter".swap(" ")')) - - def test_cross_attention_control_with_attention(self): - flames_to_trees_fire = Conjunction([FlattenedPrompt([ - CrossAttentionControlSubstitute([Fragment('flames',0.5)], [Fragment('trees',0.7)]), - Fragment(',', 1), Fragment('fire', 2.0)])]) - self.assertEqual(flames_to_trees_fire, parse_prompt('"(flames)0.5".swap("(trees)0.7"), (fire)2.0')) - flames_to_trees_fire = Conjunction([FlattenedPrompt([ - CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7)]), - Fragment(',', 1), Fragment('fire', 2.0)])]) - self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7"), (fire)2.0')) - flames_to_trees_fire = Conjunction([FlattenedPrompt([ - CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7), Fragment('houses', 1)]), - Fragment(',', 1), Fragment('fire', 2.0)])]) - self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0')) - - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1), - CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]), - Fragment('eating a', 1), - CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))]) - ])]), - parse_prompt("a cat.swap(dog) eating a hotdog.swap(hotdog++++)")) - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1), - CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]), - Fragment('eating a', 1), - CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))]) - ])]), - parse_prompt("a cat.swap(dog) eating a hotdog.swap(hotdog++++, shape_freedom=0.5)")) - - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1), - CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]), - Fragment('eating a', 1), - CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))]) - ])]), - parse_prompt("a cat.swap(dog) eating a hotdog.swap(\"hotdog++++\", shape_freedom=0.5)")) - - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1), - CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]), - Fragment('eating a', 1), - CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(1.1,4))]) - ])]), - parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog++++, shape_freedom=0.5)")) - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1), - CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]), - Fragment('eating a', 1), - CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(1.1,4))]) - ])]), - parse_prompt("a cat.swap(dog) eating a hotdog.swap(\"h\(o\)tdog++++\", shape_freedom=0.5)")) - - self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1), - CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]), - Fragment('eating a', 1), - CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(0.9,1))]) - ])]), - parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog-, shape_freedom=0.5)")) - - - def test_cross_attention_control_options(self): - self.assertEqual(Conjunction([ - FlattenedPrompt([Fragment('a', 1), - CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start':0.1}), 
- Fragment('eating a hotdog', 1)])]), - parse_prompt("a \"cat\".swap(dog, s_start=0.1) eating a hotdog")) - self.assertEqual(Conjunction([ - FlattenedPrompt([Fragment('a', 1), - CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'t_start':0.1}), - Fragment('eating a hotdog', 1)])]), - parse_prompt("a \"cat\".swap(dog, t_start=0.1) eating a hotdog")) - self.assertEqual(Conjunction([ - FlattenedPrompt([Fragment('a', 1), - CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start': 20.0, 't_start':0.1}), - Fragment('eating a hotdog', 1)])]), - parse_prompt("a \"cat\".swap(dog, t_start=0.1, s_start=20) eating a hotdog")) - - self.assertEqual( - Conjunction([ - FlattenedPrompt([Fragment('a fantasy forest landscape', 1), - CrossAttentionControlSubstitute([Fragment('', 1)], [Fragment('with a river', 1)], - options={'s_start': 0.8, 't_start': 0.8})])]), - parse_prompt("a fantasy forest landscape \"\".swap(with a river, s_start=0.8, t_start=0.8)")) - - - def test_escaping(self): - - # make sure ", ( and ) can be escaped - - self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain \(man\)')) - self.assertEqual(make_basic_conjunction(['mountain (man )']),parse_prompt('mountain (\(man)\)')) - self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain (\(man\))')) - self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('mountain (\(man\))+')) - self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('"mountain" (\(man\))+')) - self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('(man)', 1.1)]), parse_prompt('\\"mountain\\" (\(man\))+')) - # same weights for each are combined into one - self.assertEqual(make_weighted_conjunction([('"mountain" (man)', 1.1)]), parse_prompt('(\\"mountain\\")+ (\(man\))+')) - self.assertEqual(make_weighted_conjunction([('"mountain"', 1.1), ('(man)', 0.9)]), parse_prompt('(\\"mountain\\")+ (\(man\))-')) - - self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain (\(man\))1.1')) - self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('"mountain" (\(man\))1.1')) - self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('\(man\)', 1.1)]),parse_prompt('\\"mountain\\" (\(man\))1.1')) - # same weights for each are combined into one - self.assertEqual(make_weighted_conjunction([('\\"mountain\\" \(man\)', 1.1)]),parse_prompt('(\\"mountain\\")+ (\(man\))1.1')) - self.assertEqual(make_weighted_conjunction([('\\"mountain\\"', 1.1), ('\(man\)', 0.9)]),parse_prompt('(\\"mountain\\")1.1 (\(man\))0.9')) - - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy (mountain (\(man\))+)+')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy ((\(man\))1.1 "mountain")+')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy ("mountain" (\(man\))1.1 )+')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , man', 1.1)]),parse_prompt('hairy ("mountain, man")+')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy ("mountain, man" with a beard+)+')) - 
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, man" with a (beard)2.0)+')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\"man\\"" with a (beard)2.0)+')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, m\\"an\\"" with a (beard)2.0)+')) - - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \(with a (beard)2.0)+')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\(ith a (beard)2.0)+')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\( a (beard)2.0)+')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \)with a (beard)2.0)+')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\)ith a (beard)2.0)+')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\) a (beard)2.0)+')) - self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain , \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+')) - self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry ("mountain, \\\"man\" w\)ith a (beard)2.0)+')) - self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain , \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( ("mountain, \\\"man\" with a (beard)2.0)+')) - - self.assertEqual(make_weighted_conjunction([('mountain , \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \(with a (beard)2.0)+ hairy')) - self.assertEqual(make_weighted_conjunction([('mountain , \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\(ith a (beard)2.0)+hairy')) - self.assertEqual(make_weighted_conjunction([('mountain , \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" with\( a (beard)2.0)+ hairy')) - self.assertEqual(make_weighted_conjunction([('mountain , \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \)with a (beard)2.0)+ hairy')) - self.assertEqual(make_weighted_conjunction([('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hairy')) - self.assertEqual(make_weighted_conjunction([('mountain , \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' ("mountain, \\\"man\" with\) a (beard)2.0)+ hairy')) - self.assertEqual(make_weighted_conjunction([('mou)ntain , \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+ hairy')) - 
self.assertEqual(make_weighted_conjunction([('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hai\(ry ')) - self.assertEqual(make_weighted_conjunction([('mountain , \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('("mountain, \\\"man\" with a (beard)2.0)+ hairy\(\( ')) - - def test_cross_attention_escaping(self): - - self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]), - parse_prompt('mountain (man).swap(monkey)')) - self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]), - parse_prompt('mountain (man).swap(m\(onkey)')) - self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]), - parse_prompt('mountain (m\(an).swap(m\(onkey)')) - self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]), - parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)')) - - self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]), - parse_prompt('mountain ("man").swap(monkey)')) - self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]), - parse_prompt('mountain ("man").swap("monkey")')) - self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('"man', 1)], [Fragment('monkey', 1)])])]), - parse_prompt('mountain (\\"man).swap("monkey")')) - self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]), - parse_prompt('mountain (man).swap(m\(onkey)')) - self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]), - parse_prompt('mountain (m\(an).swap(m\(onkey)')) - self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]), - parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)')) - - def test_legacy_blend(self): - pp = PromptParser() - - self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]), - FlattenedPrompt([('man mountain', 1)])], - weights=[0.5,0.5]), - pp.parse_legacy_blend('mountain man:1 man mountain:1')) - - self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]), - FlattenedPrompt([('man', 1), ('mountain', 0.9)])], - weights=[0.5,0.5]), - pp.parse_legacy_blend('mountain+ man:1 man mountain-:1')) - - self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]), - FlattenedPrompt([('man', 1), ('mountain', 0.9)])], - weights=[0.5,0.5]), - pp.parse_legacy_blend('mountain+ man:1 man mountain-')) - - self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]), - FlattenedPrompt([('man', 1), ('mountain', 0.9)])], - weights=[0.5,0.5]), - pp.parse_legacy_blend('mountain+ man: man mountain-:')) - - self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]), - FlattenedPrompt([('man mountain', 1)])], - weights=[0.75,0.25]), - pp.parse_legacy_blend('mountain man:3 man mountain:1')) - - 
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]), - FlattenedPrompt([('man mountain', 1)])], - weights=[1.0,0.0]), - pp.parse_legacy_blend('mountain man:3 man mountain:0')) - - self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]), - FlattenedPrompt([('man mountain', 1)])], - weights=[0.8,0.2]), - pp.parse_legacy_blend('"mountain man":4 man mountain')) - - - def test_single(self): - self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), - FlattenedPrompt([("a person with a hat", 1.0), - ("riding a", 1.1*1.1), - CrossAttentionControlSubstitute( - [Fragment("bicycle", pow(1.1,2))], - [Fragment("skateboard", pow(1.1,2))]) - ]) - ], weights=[0.5, 0.5]), - parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)")) - pass - - -if __name__ == '__main__': - unittest.main()
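
With the parser now maintained inside compel, coverage equivalent to the deleted suite above belongs upstream. For a local sanity check against the new dependency, a minimal smoke test could look like the following sketch; it assumes compel keeps the PromptParser.parse_conjunction entry point and the Conjunction/FlattenedPrompt/Fragment result types exercised above:

    import unittest

    from compel.prompt_parser import Conjunction, FlattenedPrompt, Fragment, PromptParser

    class CompelPromptParserSmokeTest(unittest.TestCase):
        def test_basic_parse(self):
            # same expectation as the deleted test_basic above, now against compel's parser
            pp = PromptParser()
            self.assertEqual(
                Conjunction([FlattenedPrompt([Fragment('fire flames', 1)])]),
                pp.parse_conjunction('fire flames'))

    if __name__ == '__main__':
        unittest.main()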