commit ded3f13a33
parent b9ecf93ba3

    move all prompting stuff to use compel
@@ -30,7 +30,7 @@ from ldm.invoke.generator.diffusers_pipeline import PipelineIntermediateState
 from ldm.invoke.generator.inpaint import infill_methods
 from ldm.invoke.globals import Globals, global_converted_ckpts_dir
 from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
-from ldm.invoke.prompt_parser import split_weighted_subprompts, Blend
+from compel.prompt_parser import split_weighted_subprompts, Blend
 from ldm.invoke.globals import global_models_dir
 from ldm.invoke.merge_diffusers import merge_diffusion_models
 
@@ -9,6 +9,8 @@ from typing import List, Optional, Union
 
 import click
 
+from compel import PromptParser
+
 if sys.platform == "darwin":
     os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 
@@ -23,7 +25,6 @@ from ldm.invoke.image_util import make_grid
 from ldm.invoke.log import write_log
 from ldm.invoke.model_manager import ModelManager
 from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
-from ldm.invoke.prompt_parser import PromptParser
 from ldm.invoke.readline import Completer, get_completer
 from ldm.util import url_attachment_name
 
@@ -93,9 +93,9 @@ import shlex
 import sys
 import ldm.invoke
 import ldm.invoke.pngwriter
+from compel.prompt_parser import split_weighted_subprompts
 
 from ldm.invoke.globals import Globals
-from ldm.invoke.prompt_parser import split_weighted_subprompts
 from argparse import Namespace
 from pathlib import Path
 
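The split_weighted_subprompts import above now comes from compel.prompt_parser rather than ldm.invoke.prompt_parser. For reference, a minimal sketch of the legacy colon-weight syntax it handles, assuming the compel package preserves the behaviour of the function shown at the bottom of this diff:

    # Legacy blend syntax: "subprompt:weight" pairs, normalized by default.
    from compel.prompt_parser import split_weighted_subprompts

    # Roughly [("a cat", 0.25), ("a dog", 0.75)] after weight normalization.
    print(split_weighted_subprompts("a cat:1 a dog:3"))

    # Roughly [("a cat", 1.0), ("a dog", 3.0)] when normalization is skipped.
    print(split_weighted_subprompts("a cat:1 a dog:3", skip_normalize=True))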
@@ -9,59 +9,75 @@ get_uc_and_c_and_ec() get the conditioned and unconditioned latent, an
 import re
 from typing import Union
 
-import torch
-
-from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
-    CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment
-from ..models.diffusion import cross_attention_control
+from compel import Compel
+from compel.prompt_parser import FlattenedPrompt, Blend, Fragment, CrossAttentionControlSubstitute
+
+from .devices import torch_dtype
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
-from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
-from ..modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
 from ldm.invoke.globals import Globals
 
 
 def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
 
     # lazy-load any deferred textual inversions.
     # this might take a couple of seconds the first time a textual inversion is used.
     model.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(prompt_string)
 
-    prompt, negative_prompt = get_prompt_structure(prompt_string,
-                                                   skip_normalize_legacy_blend=skip_normalize_legacy_blend)
-    conditioning = _get_conditioning_for_prompt(prompt, negative_prompt, model, log_tokens)
-
-    return conditioning
+    compel = Compel(tokenizer=model.tokenizer,
+                    text_encoder=model.text_encoder,
+                    textual_inversion_manager=model.textual_inversion_manager,
+                    dtype_for_device_getter=torch_dtype)
+
+    positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
+    positive_prompt = compel.parse_prompt_string(positive_prompt_string)
+    negative_prompt = compel.parse_prompt_string(negative_prompt_string)
+
+    if log_tokens or getattr(Globals, "log_tokenization", False):
+        log_tokenization(positive_prompt, negative_prompt, tokenizer=model.tokenizer)
+
+    c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
+    uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
+
+    tokens_count = get_tokens_for_prompt(tokenizer=model.tokenizer, parsed_prompt=positive_prompt)
+
+    ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(tokens_count_including_eos_bos=tokens_count,
+                                                         cross_attention_control_args=options.get(
+                                                             'cross_attention_control', None))
+    return uc, c, ec
 
 
-def get_prompt_structure(prompt_string, skip_normalize_legacy_blend: bool = False) -> (
+def get_prompt_structure(prompt_string, model, skip_normalize_legacy_blend: bool = False) -> (
         Union[FlattenedPrompt, Blend], FlattenedPrompt):
     """
     parse the passed-in prompt string and return tuple (positive_prompt, negative_prompt)
     """
-    prompt, negative_prompt = _parse_prompt_string(prompt_string,
-                                                   skip_normalize_legacy_blend=skip_normalize_legacy_blend)
-    return prompt, negative_prompt
+    compel = Compel(tokenizer=model.tokenizer,
+                    text_encoder=model.text_encoder,
+                    textual_inversion_manager=model.textual_inversion_manager,
+                    dtype_for_device_getter=torch_dtype)
+
+    positive_prompt_string, negative_prompt_string = split_prompt_to_positive_and_negative(prompt_string)
+    positive_prompt = compel.parse_prompt_string(positive_prompt_string)
+    negative_prompt = compel.parse_prompt_string(negative_prompt_string)
+
+    return positive_prompt, negative_prompt
 
 
-def get_tokens_for_prompt(model, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
+def get_tokens_for_prompt(tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True) -> [str]:
     text_fragments = [x.text if type(x) is Fragment else
                       (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else
                        str(x))
                       for x in parsed_prompt.children]
     text = " ".join(text_fragments)
-    tokens = model.cond_stage_model.tokenizer.tokenize(text)
+    tokens = tokenizer.tokenize(text)
     if truncate_if_too_long:
-        max_tokens_length = model.cond_stage_model.max_length - 2  # typically 75
+        max_tokens_length = tokenizer.model_max_length - 2  # typically 75
         tokens = tokens[0:max_tokens_length]
     return tokens
 
 
-def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=False) -> Union[FlattenedPrompt, Blend]:
-    # Extract Unconditioned Words From Prompt
+def split_prompt_to_positive_and_negative(prompt_string_uncleaned):
     unconditioned_words = ''
     unconditional_regex = r'\[(.*?)\]'
     unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)
 
     if len(unconditionals) > 0:
         unconditioned_words = ' '.join(unconditionals)
 
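For orientation, a minimal sketch of the new Compel-based flow introduced in this hunk; the model variable is assumed to be InvokeAI's loaded model object, exposing tokenizer, text_encoder, and textual_inversion_manager as above:

    from compel import Compel
    from ldm.invoke.devices import torch_dtype

    compel = Compel(tokenizer=model.tokenizer,
                    text_encoder=model.text_encoder,
                    textual_inversion_manager=model.textual_inversion_manager,
                    dtype_for_device_getter=torch_dtype)

    # Parse positive and negative prompt strings into prompt objects.
    positive = compel.parse_prompt_string("a fluffy cat playing with a ball")
    negative = compel.parse_prompt_string("blurry, low quality")

    # Each call returns (conditioning_tensor, options); options can carry
    # cross-attention control arguments when the prompt uses .swap().
    c, options = compel.build_conditioning_tensor_for_prompt_object(positive)
    uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative)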
@@ -71,210 +87,57 @@ def _parse_prompt_string(prompt_string_uncleaned, skip_normalize_legacy_blend=False)
         prompt_string_cleaned = re.sub(' +', ' ', clean_prompt)
     else:
         prompt_string_cleaned = prompt_string_uncleaned
-
-    pp = PromptParser()
-
-    parsed_prompt: Union[FlattenedPrompt, Blend] = None
-    legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned, skip_normalize_legacy_blend)
-    if legacy_blend is not None:
-        parsed_prompt = legacy_blend
-    else:
-        # we don't support conjunctions for now
-        parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
-
-    parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
-    return parsed_prompt, parsed_negative_prompt
+    return prompt_string_cleaned, unconditioned_words
 
 
-def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], parsed_negative_prompt: FlattenedPrompt,
-                                 model, log_tokens=False) \
-        -> tuple[torch.Tensor, torch.Tensor, InvokeAIDiffuserComponent.ExtraConditioningInfo]:
-    """
-    Process prompt structure and tokens, and return (conditioning, unconditioning, extra_conditioning_info)
-    """
+def log_tokenization(positive_prompt: Blend | FlattenedPrompt,
+                     negative_prompt: Blend | FlattenedPrompt,
+                     tokenizer):
+    print(f"\n>> [TOKENLOG] Parsed Prompt: {positive_prompt}")
+    print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {negative_prompt}")
 
-    if log_tokens or getattr(Globals, "log_tokenization", False):
-        print(f"\n>> [TOKENLOG] Parsed Prompt: {parsed_prompt}")
-        print(f"\n>> [TOKENLOG] Parsed Negative Prompt: {parsed_negative_prompt}")
+    log_tokenization_for_prompt_object(positive_prompt, tokenizer)
+    log_tokenization_for_prompt_object(negative_prompt, tokenizer, display_label_prefix="(negative prompt)")
 
-    conditioning = None
-    cac_args: cross_attention_control.Arguments = None
 
-    if type(parsed_prompt) is Blend:
-        conditioning = _get_conditioning_for_blend(model, parsed_prompt, log_tokens)
-    elif type(parsed_prompt) is FlattenedPrompt:
-        if parsed_prompt.wants_cross_attention_control:
-            conditioning, cac_args = _get_conditioning_for_cross_attention_control(model, parsed_prompt, log_tokens)
+def log_tokenization_for_prompt_object(p: Blend | FlattenedPrompt, tokenizer, display_label_prefix=None):
+    display_label_prefix = display_label_prefix or ""
+    if type(p) is Blend:
+        blend: Blend = p
+        for i, c in enumerate(blend.prompts):
+            log_tokenization_for_prompt_object(
+                c, tokenizer,
+                display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})")
+    elif type(p) is FlattenedPrompt:
+        flattened_prompt: FlattenedPrompt = p
+        if flattened_prompt.wants_cross_attention_control:
+            original_fragments = []
+            edited_fragments = []
+            for f in flattened_prompt.children:
+                if type(f) is CrossAttentionControlSubstitute:
+                    original_fragments += f.original
+                    edited_fragments += f.edited
+                else:
+                    original_fragments.append(f)
+                    edited_fragments.append(f)
+
+            original_text = " ".join([x.text for x in original_fragments])
+            log_tokenization_for_text(original_text, tokenizer,
+                                      display_label=f"{display_label_prefix}(.swap originals)")
+            edited_text = " ".join([x.text for x in edited_fragments])
+            log_tokenization_for_text(edited_text, tokenizer,
+                                      display_label=f"{display_label_prefix}(.swap replacements)")
         else:
-            conditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
-                                                                    parsed_prompt,
-                                                                    log_tokens=log_tokens,
-                                                                    log_display_label="(prompt)")
-    else:
-        raise ValueError(f"parsed_prompt is '{type(parsed_prompt)}' which is not a supported prompt type")
-
-    unconditioning, _ = _get_embeddings_and_tokens_for_prompt(model,
-                                                              parsed_negative_prompt,
-                                                              log_tokens=log_tokens,
-                                                              log_display_label="(unconditioning)")
-    if isinstance(conditioning, dict):
-        # hybrid conditioning is in play
-        unconditioning, conditioning = _flatten_hybrid_conditioning(unconditioning, conditioning)
-        if cac_args is not None:
-            print(
-                ">> Hybrid conditioning cannot currently be combined with cross attention control. Cross attention control will be ignored.")
-            cac_args = None
-
-    if type(parsed_prompt) is Blend:
-        blend: Blend = parsed_prompt
-        all_token_sequences = [get_tokens_for_prompt(model, p) for p in blend.prompts]
-        longest_token_sequence = max(all_token_sequences, key=lambda t: len(t))
-        eos_token_index = len(longest_token_sequence)+1
-    else:
-        tokens = get_tokens_for_prompt(model, parsed_prompt)
-        eos_token_index = len(tokens)+1
-    return (
-        unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
-            tokens_count_including_eos_bos=eos_token_index + 1,
-            cross_attention_control_args=cac_args
-        )
-    )
+            text = " ".join([x.text for x in flattened_prompt.children])
+            log_tokenization_for_text(text, tokenizer, display_label=display_label_prefix)
 
 
-def _get_conditioning_for_cross_attention_control(model, prompt: FlattenedPrompt, log_tokens: bool = True):
-    original_prompt = FlattenedPrompt()
-    edited_prompt = FlattenedPrompt()
-    # for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
-    original_token_count = 0
-    edited_token_count = 0
-    edit_options = []
-    edit_opcodes = []
-    # beginning of sequence
-    edit_opcodes.append(
-        ('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
-    edit_options.append(None)
-    original_token_count += 1
-    edited_token_count += 1
-    for fragment in prompt.children:
-        if type(fragment) is CrossAttentionControlSubstitute:
-            original_prompt.append(fragment.original)
-            edited_prompt.append(fragment.edited)
-
-            to_replace_token_count = _get_tokens_length(model, fragment.original)
-            replacement_token_count = _get_tokens_length(model, fragment.edited)
-            edit_opcodes.append(('replace',
-                                 original_token_count, original_token_count + to_replace_token_count,
-                                 edited_token_count, edited_token_count + replacement_token_count
-                                 ))
-            original_token_count += to_replace_token_count
-            edited_token_count += replacement_token_count
-            edit_options.append(fragment.options)
-        # elif type(fragment) is CrossAttentionControlAppend:
-        #    edited_prompt.append(fragment.fragment)
-        else:
-            # regular fragment
-            original_prompt.append(fragment)
-            edited_prompt.append(fragment)
-
-            count = _get_tokens_length(model, [fragment])
-            edit_opcodes.append(('equal', original_token_count, original_token_count + count, edited_token_count,
-                                 edited_token_count + count))
-            edit_options.append(None)
-            original_token_count += count
-            edited_token_count += count
-    # end of sequence
-    edit_opcodes.append(
-        ('equal', original_token_count, original_token_count + 1, edited_token_count, edited_token_count + 1))
-    edit_options.append(None)
-    original_token_count += 1
-    edited_token_count += 1
-    original_embeddings, original_tokens = _get_embeddings_and_tokens_for_prompt(model,
-                                                                                 original_prompt,
-                                                                                 log_tokens=log_tokens,
-                                                                                 log_display_label="(.swap originals)")
-    # naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
-    # subsequent tokens when there is >1 edit and earlier edits change the total token count.
-    # eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
-    # 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
-    # token 'smiling' in the inactive 'cat' edit.
-    # todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
-    edited_embeddings, edited_tokens = _get_embeddings_and_tokens_for_prompt(model,
-                                                                             edited_prompt,
-                                                                             log_tokens=log_tokens,
-                                                                             log_display_label="(.swap replacements)")
-    conditioning = original_embeddings
-    edited_conditioning = edited_embeddings
-    # print('>> got edit_opcodes', edit_opcodes, 'options', edit_options)
-    cac_args = cross_attention_control.Arguments(
-        edited_conditioning=edited_conditioning,
-        edit_opcodes=edit_opcodes,
-        edit_options=edit_options
-    )
-    return conditioning, cac_args
-
-
-def _get_conditioning_for_blend(model, blend: Blend, log_tokens: bool = False):
-    embeddings_to_blend = None
-    for i, flattened_prompt in enumerate(blend.prompts):
-        this_embedding, _ = _get_embeddings_and_tokens_for_prompt(model,
-                                                                  flattened_prompt,
-                                                                  log_tokens=log_tokens,
-                                                                  log_display_label=f"(blend part {i + 1}, weight={blend.weights[i]})")
-        embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
-            (embeddings_to_blend, this_embedding))
-    conditioning = WeightedPromptFragmentsToEmbeddingsConverter.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
-                                                                                        blend.weights,
-                                                                                        normalize=blend.normalize_weights)
-    return conditioning
-
-
-def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool = False,
-                                          log_display_label: str = None):
-    if type(flattened_prompt) is not FlattenedPrompt:
-        raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
-    fragments = [x.text for x in flattened_prompt.children]
-    weights = [x.weight for x in flattened_prompt.children]
-    embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
-    if log_tokens or getattr(Globals, "log_tokenization", False):
-        text = " ".join(fragments)
-        log_tokenization(text, model, display_label=log_display_label)
-
-    return embeddings, tokens
-
-
-def _get_tokens_length(model, fragments: list[Fragment]):
-    fragment_texts = [x.text for x in fragments]
-    tokens = model.cond_stage_model.get_token_ids(fragment_texts, include_start_and_end_markers=False)
-    return sum([len(x) for x in tokens])
-
-
-def _flatten_hybrid_conditioning(uncond, cond):
-    '''
-    This handles the choice between a conditional conditioning
-    that is a tensor (used by cross attention) vs one that has additional
-    dimensions as well, as used by 'hybrid'
-    '''
-    assert isinstance(uncond, dict)
-    assert isinstance(cond, dict)
-    cond_flattened = dict()
-    for k in cond:
-        if isinstance(cond[k], list):
-            cond_flattened[k] = [
-                torch.cat([uncond[k][i], cond[k][i]])
-                for i in range(len(cond[k]))
-            ]
-        else:
-            cond_flattened[k] = torch.cat([uncond[k], cond[k]])
-    return uncond, cond_flattened
-
-
-def log_tokenization(text, model, display_label=None):
+def log_tokenization_for_text(text, tokenizer, display_label=None):
     """ shows how the prompt is tokenized
     # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '
     """
-    tokens = model.cond_stage_model.tokenizer.tokenize(text)
+    tokens = tokenizer.tokenize(text)
     tokenized = ""
     discarded = ""
     usedTokens = 0
@@ -284,7 +147,7 @@ def log_tokenization(text, model, display_label=None):
         token = tokens[i].replace('</w>', ' ')
         # alternate color
         s = (usedTokens % 6) + 1
-        if i < model.cond_stage_model.max_length:
+        if i < tokenizer.model_max_length:
             tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
             usedTokens += 1
         else:  # over max token length
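A short sketch of the colour-coded token logging retained above, assuming the enclosing module remains importable as ldm.invoke.conditioning and that a standard HuggingFace CLIP tokenizer is available:

    from transformers import CLIPTokenizer
    from ldm.invoke.conditioning import log_tokenization_for_text

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    # Prints each token in an alternating ANSI colour; tokens past
    # tokenizer.model_max_length are collected and reported as discarded.
    log_tokenization_for_text("a photograph of an astronaut riding a horse",
                              tokenizer, display_label="(prompt)")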
@@ -31,7 +31,7 @@ from ldm.modules.textual_inversion_manager import TextualInversionManager
 from ..devices import normalize_device, CPU_DEVICE
 from ..offloading import LazilyLoadedModelGroup, FullyLoadedModelGroup, ModelGroup
 from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
-from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
+from compel import EmbeddingsProvider
 
 
 @dataclass
@@ -294,7 +294,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             text_encoder=self.text_encoder,
             full_precision=use_full_precision)
         # InvokeAI's interface for text embeddings and whatnot
-        self.prompt_fragments_to_embeddings_converter = WeightedPromptFragmentsToEmbeddingsConverter(
+        self.embeddings_provider = EmbeddingsProvider(
             tokenizer=self.tokenizer,
             text_encoder=self.text_encoder,
             textual_inversion_manager=self.textual_inversion_manager
@@ -726,15 +726,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         """
         Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
         """
-        return self.prompt_fragments_to_embeddings_converter.get_embeddings_for_weighted_prompt_fragments(
-            text=c,
-            fragment_weights=fragment_weights,
+        return self.embeddings_provider.get_embeddings_for_weighted_prompt_fragments(
+            text_batch=c,
+            fragment_weights_batch=fragment_weights,
             should_return_tokens=return_tokens,
             device=self._model_group.device_for(self.unet))
 
     @property
     def cond_stage_model(self):
-        return self.prompt_fragments_to_embeddings_converter
+        return self.embeddings_provider
 
     @torch.inference_mode()
     def _tokenize(self, prompt: Union[str, List[str]]):
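Note the renamed keyword arguments in the compatibility shim above (text becomes text_batch, fragment_weights becomes fragment_weights_batch). A sketch of an equivalent direct call, assuming a pipeline constructed as in the previous hunk; the fragment lists and weights here are illustrative:

    # One list of text fragments per batch item, with matching per-fragment weights.
    embeddings, tokens = pipeline.embeddings_provider.get_embeddings_for_weighted_prompt_fragments(
        text_batch=[["a cat", "playing"]],
        fragment_weights_batch=[[1.0, 0.8]],
        should_return_tokens=True,
        device="cuda")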
@@ -1,655 +0,0 @@
-import string
-from typing import Union, Optional
-import re
-import pyparsing as pp
-'''
-This module parses prompt strings and produces tree-like structures that can be used generate and control the conditioning tensors.
-weighted subprompts.
-
-Useful class exports:
-
-PromptParser - parses prompts
-
-Useful function exports:
-
-split_weighted_subpromopts()    split subprompts, normalize and weight them
-log_tokenization()              print out colour-coded tokens and warn if truncated
-'''
-
-class Prompt():
-    """
-    Mid-level structure for storing the tree-like result of parsing a prompt. A Prompt may not represent the whole of
-    the singular user-defined "prompt string" (although it can) - for example, if the user specifies a Blend, the objects
-    that are to be blended together are stored individuall as Prompt objects.
-
-    Nesting makes this object not suitable for directly tokenizing; instead call flatten() on the containing Conjunction
-    to produce a FlattenedPrompt.
-    """
-    def __init__(self, parts: list):
-        for c in parts:
-            if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
-                raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} ({c}), only {[c.__name__ for c in BaseFragment.__subclasses__()]} are allowed")
-        self.children = parts
-    def __repr__(self):
-        return f"Prompt:{self.children}"
-    def __eq__(self, other):
-        return type(other) is Prompt and other.children == self.children
-
-class BaseFragment:
-    pass
-
-class FlattenedPrompt():
-    """
-    A Prompt that has been passed through flatten(). Its children can be readily tokenized.
-    """
-    def __init__(self, parts: list=[]):
-        self.children = []
-        for part in parts:
-            self.append(part)
-
-    def append(self, fragment: Union[list, BaseFragment, tuple]):
-        # verify type correctness
-        if type(fragment) is list:
-            for x in fragment:
-                self.append(x)
-        elif issubclass(type(fragment), BaseFragment):
-            self.children.append(fragment)
-        elif type(fragment) is tuple:
-            # upgrade tuples to Fragments
-            if type(fragment[0]) is not str or (type(fragment[1]) is not float and type(fragment[1]) is not int):
-                raise PromptParser.ParsingException(
-                    f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
-            self.children.append(Fragment(fragment[0], fragment[1]))
-        else:
-            raise PromptParser.ParsingException(
-                f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
-
-    @property
-    def is_empty(self):
-        return len(self.children) == 0 or \
-               (len(self.children) == 1 and len(self.children[0].text) == 0)
-
-    @property
-    def wants_cross_attention_control(self):
-        return any(
-            [issubclass(type(x), CrossAttentionControlledFragment) for x in self.children]
-        )
-
-    def __repr__(self):
-        return f"FlattenedPrompt:{self.children}"
-    def __eq__(self, other):
-        return type(other) is FlattenedPrompt and other.children == self.children
-
-
-class Fragment(BaseFragment):
-    """
-    A Fragment is a chunk of plain text and an optional weight. The text should be passed as-is to the CLIP tokenizer.
-    """
-    def __init__(self, text: str, weight: float=1):
-        assert(type(text) is str)
-        if '\\"' in text or '\\(' in text or '\\)' in text:
-            #print("Fragment converting escaped \( \) \\\" into ( ) \"")
-            text = text.replace('\\(', '(').replace('\\)', ')').replace('\\"', '"')
-        self.text = text
-        self.weight = float(weight)
-
-    def __repr__(self):
-        return "Fragment:'"+self.text+"'@"+str(self.weight)
-    def __eq__(self, other):
-        return type(other) is Fragment \
-            and other.text == self.text \
-            and other.weight == self.weight
-
-class Attention():
-    """
-    Nestable weight control for fragments. Each object in the children array may in turn be an Attention object;
-    weights should be considered to accumulate as the tree is traversed to deeper levels of nesting.
-
-    Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object.
-    """
-    def __init__(self, weight: float, children: list):
-        if type(weight) is not float:
-            raise PromptParser.ParsingException(
-                f"Attention weight must be float (got {type(weight).__name__} {weight})")
-        self.weight = weight
-        if type(children) is not list:
-            raise PromptParser.ParsingException(f"cannot make Attention with non-list of children (got {type(children)})")
-        assert(type(children) is list)
-        self.children = children
-        #print(f"A: requested attention '{children}' to {weight}")
-
-    def __repr__(self):
-        return f"Attention:{self.children} * {self.weight}"
-    def __eq__(self, other):
-        return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment
-
-class CrossAttentionControlledFragment(BaseFragment):
-    pass
-
-class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
-    """
-    A Cross-Attention Controlled ('prompt2prompt') fragment, for use inside a Prompt, Attention, or FlattenedPrompt.
-    Representing an "original" word sequence that supplies feature vectors for an initial diffusion operation, and an
-    "edited" word sequence, to which the attention maps produced by the "original" word sequence are applied. Intuitively,
-    the result should be an "edited" image that looks like the "original" image with concepts swapped.
-
-    eg "a cat sitting on a car" (original) -> "a smiling dog sitting on a car" (edited): the edited image should look
-    almost exactly the same as the original, but with a smiling dog rendered in place of the cat. The
-    CrossAttentionControlSubstitute object representing this swap may be confined to the tokens being swapped:
-        CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')])
-    or it may represent a larger portion of the token sequence:
-        CrossAttentionControlSubstitute(original=[Fragment('a cat sitting on a car')],
-                                        edited=[Fragment('a smiling dog sitting on a car')])
-
-    In either case expect it to be embedded in a Prompt or FlattenedPrompt:
-        FlattenedPrompt([
-                        Fragment('a'),
-                        CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]),
-                        Fragment('sitting on a car')
-                        ])
-    """
-    def __init__(self, original: list, edited: list, options: dict=None):
-        self.original = original if len(original)>0 else [Fragment('')]
-        self.edited = edited if len(edited)>0 else [Fragment('')]
-
-        default_options = {
-            's_start': 0.0,
-            's_end': 0.2062994740159002,  # ~= shape_freedom=0.5
-            't_start': 0.1,
-            't_end': 1.0
-        }
-        merged_options = default_options
-        if options is not None:
-            shape_freedom = options.pop('shape_freedom', None)
-            if shape_freedom is not None:
-                # high shape freedom = SD can do what it wants with the shape of the object
-                # high shape freedom => s_end = 0
-                # low shape freedom => s_end = 1
-                # shape freedom is in a "linear" space, while noticeable changes to s_end are typically closer around 0,
-                # and there is very little perceptible difference as s_end increases above 0.5
-                # so for shape_freedom = 0.5 we probably want s_end to be 0.2
-                #  -> cube root and subtract from 1.0
-                merged_options['s_end'] = 1.0 - shape_freedom ** (1. / 3.)
-                #print('converted shape_freedom argument to', merged_options)
-            merged_options.update(options)
-
-        self.options = merged_options
-
-    def __repr__(self):
-        return f"CrossAttentionControlSubstitute:({self.original}->{self.edited} ({self.options})"
-    def __eq__(self, other):
-        return type(other) is CrossAttentionControlSubstitute \
-            and other.original == self.original \
-            and other.edited == self.edited \
-            and other.options == self.options
-
-
-class CrossAttentionControlAppend(CrossAttentionControlledFragment):
-    def __init__(self, fragment: Fragment):
-        self.fragment = fragment
-    def __repr__(self):
-        return "CrossAttentionControlAppend:",self.fragment
-    def __eq__(self, other):
-        return type(other) is CrossAttentionControlAppend \
-            and other.fragment == self.fragment
-
-
-
-class Conjunction():
-    """
-    Storage for one or more Prompts or Blends, each of which is to be separately diffused and then the results merged
-    by weighted sum in latent space.
-    """
-    def __init__(self, prompts: list, weights: list = None):
-        # force everything to be a Prompt
-        #print("making conjunction with", prompts, "types", [type(p).__name__ for p in prompts])
-        self.prompts = [x if (type(x) is Prompt
-                              or type(x) is Blend
-                              or type(x) is FlattenedPrompt)
-                        else Prompt(x) for x in prompts]
-        self.weights = [1.0]*len(self.prompts) if (weights is None or len(weights)==0) else list(weights)
-        if len(self.weights) != len(self.prompts):
-            raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
-        self.type = 'AND'
-
-    def __repr__(self):
-        return f"Conjunction:{self.prompts} | weights {self.weights}"
-    def __eq__(self, other):
-        return type(other) is Conjunction \
-            and other.prompts == self.prompts \
-            and other.weights == self.weights
-
-
-class Blend():
-    """
-    Stores a Blend of multiple Prompts. To apply, build feature vectors for each of the child Prompts and then perform a
-    weighted blend of the feature vectors to produce a single feature vector that is effectively a lerp between the
-    Prompts.
-    """
-    def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
-        #print("making Blend with prompts", prompts, "and weights", weights)
-        weights = [1.0]*len(prompts) if (weights is None or len(weights)==0) else list(weights)
-        if len(prompts) != len(weights):
-            raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
-        for p in prompts:
-            if type(p) is not Prompt and type(p) is not FlattenedPrompt:
-                raise(PromptParser.ParsingException(f"{type(p)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
-            for f in p.children:
-                if isinstance(f, CrossAttentionControlSubstitute):
-                    raise(PromptParser.ParsingException(f"while parsing Blend: sorry, you cannot do .swap() as part of a Blend"))
-
-        # upcast all lists to Prompt objects
-        self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
-                        else Prompt(x)
-                        for x in prompts]
-        self.prompts = prompts
-        self.weights = weights
-        self.normalize_weights = normalize_weights
-
-    @property
-    def wants_cross_attention_control(self):
-        # blends cannot cross-attention control
-        return False
-
-
-    def __repr__(self):
-        return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
-    def __eq__(self, other):
-        return other.__repr__() == self.__repr__()
-
-
-class PromptParser():
-
-    class ParsingException(Exception):
-        pass
-
-    class UnrecognizedOperatorException(ParsingException):
-        def __init__(self, operator:str):
-            super().__init__("Unrecognized operator: " + operator)
-
-    def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
-
-        self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base)
-
-
-    def parse_conjunction(self, prompt: str) -> Conjunction:
-        '''
-        :param prompt: The prompt string to parse
-        :return: a Conjunction representing the parsed results.
-        '''
-        #print(f"!!parsing '{prompt}'")
-
-        if len(prompt.strip()) == 0:
-            return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0])
-
-        root = self.conjunction.parse_string(prompt)
-        #print(f"'{prompt}' parsed to root", root)
-        #fused = fuse_fragments(parts)
-        #print("fused to", fused)
-
-        return self.flatten(root[0])
-
-    def parse_legacy_blend(self, text: str, skip_normalize: bool = False) -> Optional[Blend]:
-        weighted_subprompts = split_weighted_subprompts(text, skip_normalize=skip_normalize)
-        if len(weighted_subprompts) <= 1:
-            return None
-        strings = [x[0] for x in weighted_subprompts]
-        weights = [x[1] for x in weighted_subprompts]
-
-        parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
-        flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
-
-        return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=not skip_normalize)
-
-
-    def flatten(self, root: Conjunction, verbose = False) -> Conjunction:
-        """
-        Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
-        producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
-        that can be readily tokenized without the need to walk a complex tree structure.
-
-        :param root: The Conjunction to flatten.
-        :return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
-        """
-
-        def fuse_fragments(items):
-            # print("fusing fragments in ", items)
-            result = []
-            for x in items:
-                if type(x) is CrossAttentionControlSubstitute:
-                    original_fused = fuse_fragments(x.original)
-                    edited_fused = fuse_fragments(x.edited)
-                    result.append(CrossAttentionControlSubstitute(original_fused, edited_fused, options=x.options))
-                else:
-                    last_weight = result[-1].weight \
-                        if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \
-                        else None
-                    this_text = x.text
-                    this_weight = x.weight
-                    if last_weight is not None and last_weight == this_weight:
-                        last_text = result[-1].text
-                        result[-1] = Fragment(last_text + ' ' + this_text, last_weight)
-                    else:
-                        result.append(x)
-            return result
-
-        def flatten_internal(node, weight_scale, results, prefix):
-            verbose and print(prefix + "flattening", node, "...")
-            if type(node) is pp.ParseResults or type(node) is list:
-                for x in node:
-                    results = flatten_internal(x, weight_scale, results, prefix+' pr ')
-                #print(prefix, " ParseResults expanded, results is now", results)
-            elif type(node) is Attention:
-                # if node.weight < 1:
-                # todo: inject a blend when flattening attention with weight <1"
-                for index,c in enumerate(node.children):
-                    results = flatten_internal(c, weight_scale * node.weight, results, prefix + f" att{index} ")
-            elif type(node) is Fragment:
-                results += [Fragment(node.text, node.weight*weight_scale)]
-            elif type(node) is CrossAttentionControlSubstitute:
-                original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ')
-                edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ')
-                results += [CrossAttentionControlSubstitute(original, edited, options=node.options)]
-            elif type(node) is Blend:
-                flattened_subprompts = []
-                #print(" flattening blend with prompts", node.prompts, "weights", node.weights)
-                for prompt in node.prompts:
-                    # prompt is a list
-                    flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ')
-                results += [Blend(prompts=flattened_subprompts, weights=node.weights, normalize_weights=node.normalize_weights)]
-            elif type(node) is Prompt:
-                #print(prefix + "about to flatten Prompt with children", node.children)
-                flattened_prompt = []
-                for child in node.children:
-                    flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ')
-                results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))]
-                #print(prefix + "after flattening Prompt, results is", results)
-            else:
-                raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
-            verbose and print(prefix + "-> after flattening", type(node).__name__, "results is", results)
-            return results
-
-        verbose and print("flattening", root)
-
-        flattened_parts = []
-        for part in root.prompts:
-            flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
-
-        verbose and print("flattened to", flattened_parts)
-
-        weights = root.weights
-        return Conjunction(flattened_parts, weights)
-
-
-
-
-def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
-    def make_operator_object(x):
-        #print('making operator for', x)
-        target = x[0]
-        operator = x[1]
-        arguments = x[2]
-        if operator == '.attend':
-            weight_raw = arguments[0]
-            weight = 1.0
-            if type(weight_raw) is float or type(weight_raw) is int:
-                weight = weight_raw
-            elif type(weight_raw) is str:
-                base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
-                weight = pow(base, len(weight_raw))
-            return Attention(weight=weight, children=[x for x in x[0]])
-        elif operator == '.swap':
-            return CrossAttentionControlSubstitute(target, arguments, x.as_dict())
-        elif operator == '.blend':
-            prompts = [Prompt(p) for p in x[0]]
-            weights_raw = x[2]
-            normalize_weights = True
-            if len(weights_raw) > 0 and weights_raw[-1][0] == 'no_normalize':
-                normalize_weights = False
-                weights_raw = weights_raw[:-1]
-            weights = [float(w[0]) for w in weights_raw]
-            return Blend(prompts=prompts, weights=weights, normalize_weights=normalize_weights)
-        elif operator == '.and' or operator == '.add':
-            prompts = [Prompt(p) for p in x[0]]
-            weights = [float(w[0]) for w in x[2]]
-            return Conjunction(prompts=prompts, weights=weights)
-
-        raise PromptParser.UnrecognizedOperatorException(operator)
-
-    def parse_fragment_str(x, expression: pp.ParseExpression, in_quotes: bool = False, in_parens: bool = False):
-        #print(f"parsing fragment string for {x}")
-        fragment_string = x[0]
-        if len(fragment_string.strip()) == 0:
-            return Fragment('')
-
-        if in_quotes:
-            # escape unescaped quotes
-            fragment_string = fragment_string.replace('"', '\\"')
-
-        try:
-            result = (expression + pp.StringEnd()).parse_string(fragment_string)
-            #print("parsed to", result)
-            return result
-        except pp.ParseException as e:
-            #print("parse_fragment_str couldn't parse prompt string:", e)
-            raise
-
-    # meaningful symbols
-    lparen = pp.Literal("(").suppress()
-    rparen = pp.Literal(")").suppress()
-    quote = pp.Literal('"').suppress()
-    comma = pp.Literal(",").suppress()
-    dot = pp.Literal(".").suppress()
-    equals = pp.Literal("=").suppress()
-
-    escaped_lparen = pp.Literal('\\(')
-    escaped_rparen = pp.Literal('\\)')
-    escaped_quote = pp.Literal('\\"')
-    escaped_comma = pp.Literal('\\,')
-    escaped_dot = pp.Literal('\\.')
-    escaped_plus = pp.Literal('\\+')
-    escaped_minus = pp.Literal('\\-')
-    escaped_equals = pp.Literal('\\=')
-
-    syntactic_symbols = {
-        '(': escaped_lparen,
-        ')': escaped_rparen,
-        '"': escaped_quote,
-        ',': escaped_comma,
-        '.': escaped_dot,
-        '+': escaped_plus,
-        '-': escaped_minus,
-        '=': escaped_equals,
-    }
-    syntactic_chars = "".join(syntactic_symbols.keys())
-
-    # accepts int or float notation, always maps to float
-    number = pp.pyparsing_common.real | \
-             pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
-
-    # for options
-    keyword = pp.Word(pp.alphanums + '_')
-
-    # a word that absolutely does not contain any meaningful syntax
-    non_syntax_word = pp.Combine(pp.OneOrMore(pp.MatchFirst([
-        pp.Or(syntactic_symbols.values()),
-        pp.one_of(['-', '+']) + pp.NotAny(pp.White() | pp.Char(syntactic_chars) | pp.StringEnd()),
-        # build character-by-character
-        pp.CharsNotIn(string.whitespace + syntactic_chars, exact=1)
-    ])))
-    non_syntax_word.set_parse_action(lambda x: [Fragment(t) for t in x])
-    non_syntax_word.set_name('non_syntax_word')
-    non_syntax_word.set_debug(False)
-
-    # a word that can contain any character at all - greedily consumes syntax, so use with care
-    free_word = pp.CharsNotIn(string.whitespace).set_parse_action(lambda x: Fragment(x[0]))
-    free_word.set_name('free_word')
-    free_word.set_debug(False)
-
-
-    # ok here we go. forward declare some things..
-    attention = pp.Forward()
-    cross_attention_substitute = pp.Forward()
-    parenthesized_fragment = pp.Forward()
-    quoted_fragment = pp.Forward()
-
-    # the types of things that can go into a fragment, consisting of syntax-full and/or strictly syntax-free components
-    fragment_part_expressions = [
-        attention,
-        cross_attention_substitute,
-        parenthesized_fragment,
-        quoted_fragment,
-        non_syntax_word
-    ]
-    # a fragment that is permitted to contain commas
-    fragment_including_commas = pp.ZeroOrMore(pp.MatchFirst(
-        fragment_part_expressions + [
-            pp.Literal(',').set_parse_action(lambda x: Fragment(x[0]))
-        ]
-    ))
-    # a fragment that is not permitted to contain commas
-    fragment_excluding_commas = pp.ZeroOrMore(pp.MatchFirst(
-        fragment_part_expressions
-    ))
-
-    # a fragment in double quotes (may be nested)
-    quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
-    quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, fragment_including_commas, in_quotes=True))
-
-    # a fragment inside parentheses (may be nested)
-    parenthesized_fragment << (lparen + fragment_including_commas + rparen)
-    parenthesized_fragment.set_name('parenthesized_fragment')
-    parenthesized_fragment.set_debug(False)
-
-    # a string of the form (<keyword>=<float|keyword> | <float> | <keyword>) where keyword is alphanumeric + '_'
-    option = pp.Group(pp.MatchFirst([
-        keyword + equals + (number | keyword),  # option=value
-        number.copy().set_parse_action(pp.token_map(str)),  # weight
-        keyword  # flag
-    ]))
-    # options for an operator, eg "s_start=0.1, 0.3, no_normalize"
-    options = pp.Dict(pp.Optional(pp.delimited_list(option)))
-    options.set_name('options')
-    options.set_debug(False)
-
-    # a fragment which can be used as the target for an operator - either quoted or in parentheses, or a bare vanilla word
-    potential_operator_target = (quoted_fragment | parenthesized_fragment | non_syntax_word)
-
-    # a fragment whose weight has been increased or decreased by a given amount
-    attention_weight_operator = pp.Word('+') | pp.Word('-') | number
-    attention_explicit = (
-        pp.Group(potential_operator_target)
-        + pp.Literal('.attend')
-        + lparen
-        + pp.Group(attention_weight_operator)
-        + rparen
-    )
-    attention_explicit.set_parse_action(make_operator_object)
-    attention_implicit = (
-        pp.Group(potential_operator_target)
-        + pp.NotAny(pp.White())  # do not permit whitespace between term and operator
-        + pp.Group(attention_weight_operator)
-    )
-    attention_implicit.set_parse_action(lambda x: make_operator_object([x[0], '.attend', x[1]]))
-    attention << (attention_explicit | attention_implicit)
-    attention.set_name('attention')
-    attention.set_debug(False)
-
-    # cross-attention control by swapping one fragment for another
-    cross_attention_substitute << (
-        pp.Group(potential_operator_target).set_name('ca-target').set_debug(False)
-        + pp.Literal(".swap").set_name('ca-operator').set_debug(False)
-        + lparen
-        + pp.Group(fragment_excluding_commas).set_name('ca-replacement').set_debug(False)
-        + pp.Optional(comma + options).set_name('ca-options').set_debug(False)
-        + rparen
-    )
-    cross_attention_substitute.set_name('cross_attention_substitute')
-    cross_attention_substitute.set_debug(False)
-    cross_attention_substitute.set_parse_action(make_operator_object)
-
-
-    # an entire self-contained prompt, which can be used in a Blend or Conjunction
-    prompt = pp.ZeroOrMore(pp.MatchFirst([
-        cross_attention_substitute,
-        attention,
-        quoted_fragment,
-        parenthesized_fragment,
-        free_word,
-        pp.White().suppress()
-    ]))
-    quoted_prompt = quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, prompt, in_quotes=True))
-
-
-    # a blend/lerp between the feature vectors for two or more prompts
-    blend = (
-        lparen
-        + pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('bl-target').set_debug(False)
-        + rparen
-        + pp.Literal(".blend").set_name('bl-operator').set_debug(False)
-        + lparen
-        + pp.Group(options).set_name('bl-options').set_debug(False)
-        + rparen
-    )
-    blend.set_name('blend')
-    blend.set_debug(False)
-    blend.set_parse_action(make_operator_object)
-
-    # an operator to direct stable diffusion to step multiple times, once for each target, and then add the results together with different weights
-    explicit_conjunction = (
-        lparen
-        + pp.Group(pp.delimited_list(pp.Group(potential_operator_target | quoted_prompt), min=1)).set_name('cj-target').set_debug(False)
-        + rparen
-        + pp.one_of([".and", ".add"]).set_name('cj-operator').set_debug(False)
-        + lparen
-        + pp.Group(options).set_name('cj-options').set_debug(False)
-        + rparen
-    )
-    explicit_conjunction.set_name('explicit_conjunction')
-    explicit_conjunction.set_debug(False)
-    explicit_conjunction.set_parse_action(make_operator_object)
-
-    # by default a prompt consists of a Conjunction with a single term
-    implicit_conjunction = (blend | pp.Group(prompt)) + pp.StringEnd()
-    implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
-
-    conjunction = (explicit_conjunction | implicit_conjunction)
-
-    return conjunction, prompt
-
-
-def split_weighted_subprompts(text, skip_normalize=False)->list:
-    """
-    Legacy blend parsing.
-
-    grabs all text up to the first occurrence of ':'
-    uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
-    if ':' has no value defined, defaults to 1.0
-    repeats until no text remaining
-    """
-    prompt_parser = re.compile("""
-            (?P<prompt>         # capture group for 'prompt'
-            (?:\\\:|[^:])+      # match one or more non ':' characters or escaped colons '\:'
-            )                   # end 'prompt'
-            (?:                 # non-capture group
-            :+                  # match one or more ':' characters
-            (?P<weight>         # capture group for 'weight'
-            -?\d+(?:\.\d+)?     # match positive or negative integer or decimal number
-            )?                  # end weight capture group, make optional
-            \s*                 # strip spaces after weight
-            |                   # OR
-            $                   # else, if no ':' then match end of line
-            )                   # end non-capture group
-            """, re.VERBOSE)
-    parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
-        match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
-    if skip_normalize:
-        return parsed_prompts
-    weight_sum = sum(map(lambda x: x[1], parsed_prompts))
-    if weight_sum == 0:
-        print(
-            "* Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
-        equal_weight = 1 / max(len(parsed_prompts), 1)
-        return [(x[0], equal_weight) for x in parsed_prompts]
-    return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
|
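A worked example of the legacy syntax handled above, using the 3:1 weighting exercised by test_legacy_blend below (a sketch, not part of this diff):

    split_weighted_subprompts("mountain man:3 man mountain:1")
    # -> [('mountain man', 0.75), ('man mountain', 0.25)]   # weights normalized to sum to 1
    split_weighted_subprompts("mountain man:3 man mountain:1", skip_normalize=True)
    # -> [('mountain man', 3.0), ('man mountain', 1.0)]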
@@ -1,3 +1,8 @@
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl


import enum
import math
from typing import Optional, Callable
@@ -6,35 +11,13 @@ import psutil
import torch
import diffusers
from torch import nn

from compel.cross_attention_control import Arguments
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.cross_attention import AttnProcessor
from ldm.invoke.devices import torch_dtype

# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl

class Arguments:
    def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
        """
        :param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768]
        :param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required)
        :param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes.
        """
        # todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector
        self.edited_conditioning = edited_conditioning
        self.edit_opcodes = edit_opcodes

        if edited_conditioning is not None:
            assert len(edit_opcodes) == len(edit_options), \
                "there must be 1 edit_options dict for each edit_opcodes tuple"
            non_none_edit_options = [x for x in edit_options if x is not None]
            assert len(non_none_edit_options) > 0, "missing edit_options"
            if len(non_none_edit_options) > 1:
                print('warning: cross-attention control options are not working properly for >1 edit')
            self.edit_options = non_none_edit_options[0]


class CrossAttentionType(enum.Enum):
    SELF = 1
    TOKENS = 2
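The edit_opcodes consumed above follow difflib.SequenceMatcher conventions, as the docstring notes; a quick illustration with made-up token strings (a sketch, not part of this diff):

    import difflib

    original = ['a', 'cat', 'eating', 'a', 'hotdog']
    edited = ['a', 'dog', 'eating', 'a', 'hotdog']
    opcodes = difflib.SequenceMatcher(None, original, edited).get_opcodes()
    # -> [('equal', 0, 1, 0, 1), ('replace', 1, 2, 1, 2), ('equal', 2, 5, 2, 5)]
    # per the docstring above, only the 'equal' spans are required by the controller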
@@ -319,7 +302,6 @@ def override_cross_attention(model, context: Context, is_running_diffusers = Fal
    Inject attention parameters and functions into the passed in model to enable cross attention editing.

    :param model: The unet model to inject into.
    :param cross_attention_control_args: Arguments passed to the CrossAttentionControl implementations
    :return: None
    """
@@ -523,7 +505,7 @@ from dataclasses import field, dataclass

import torch

from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor, AttnProcessor
from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor


@dataclass
@@ -1,236 +0,0 @@
import math

import torch
from transformers import CLIPTokenizer, CLIPTextModel

from ldm.invoke.devices import torch_dtype
from ldm.modules.textual_inversion_manager import TextualInversionManager


class WeightedPromptFragmentsToEmbeddingsConverter():

    def __init__(self,
                 tokenizer: CLIPTokenizer,  # converts strings to lists of int token ids
                 text_encoder: CLIPTextModel,  # converts a list of int token ids to a tensor of embeddings
                 textual_inversion_manager: TextualInversionManager = None
                 ):
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.textual_inversion_manager = textual_inversion_manager

    @property
    def max_length(self):
        return self.tokenizer.model_max_length

    def get_embeddings_for_weighted_prompt_fragments(self,
                                                     text: list[list[str]],
                                                     fragment_weights: list[list[float]],
                                                     should_return_tokens: bool = False,
                                                     device='cpu'
                                                     ) -> torch.Tensor:
        '''
        :param text: A list of fragments of text to which different weights are to be applied.
        :param fragment_weights: A batch of lists of weights, one for each entry in `fragments`.
        :return: A tensor of shape `[1, 77, token_dim]` containing weighted embeddings where token_dim is 768 for SD1
            and 1280 for SD2
        '''
        if len(text) != len(fragment_weights):
            raise ValueError(f"lengths of text and fragment_weights lists are not the same ({len(text)} != {len(fragment_weights)})")

        batch_z = None
        batch_tokens = None
        for fragments, weights in zip(text, fragment_weights):

            # First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
            # applying a multiplier to the CFG scale on a per-token basis).
            # For tokens weighted <1, intuitively we want SD to become not merely *less* interested in the concept
            # captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active
            # interest, however small, in redness; what the user probably intends when they attach the number 0.01 to
            # "red" is to tell SD that it should almost completely *ignore* redness).
            # To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt
            # string from which the low-weighted fragment has simply been removed. The closer the weight is to zero, the
            # closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment.

            # handle weights >=1
            tokens, per_token_weights = self.get_token_ids_and_expand_weights(fragments, weights, device=device)
            base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights)

            # this is our starting point
            embeddings = base_embedding.unsqueeze(0)
            per_embedding_weights = [1.0]

            # now handle weights <1
            # Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped
            # with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting
            # embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words
            # removed.
            # eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding
            # for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it
            # such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
            for index, fragment_weight in enumerate(weights):
                if fragment_weight < 1:
                    fragments_without_this = fragments[:index] + fragments[index+1:]
                    weights_without_this = weights[:index] + weights[index+1:]
                    tokens, per_token_weights = self.get_token_ids_and_expand_weights(fragments_without_this, weights_without_this, device=device)
                    embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights)

                    embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
                    # weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
                    # if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
                    # therefore:
                    # fragment_weight = 1: we are at base_z => lerp weight 0
                    # fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1
                    # fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf
                    # so let's use tan(), because:
                    # tan is 0.0 at 0,
                    # 1.0 at PI/4, and
                    # inf at PI/2
                    # -> tan((1-weight)*PI/2) should give us ideal lerp weights
                    epsilon = 1e-9
                    fragment_weight = max(epsilon, fragment_weight)  # inf is bad
                    embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
                    # todo handle negative weight?

                    per_embedding_weights.append(embedding_lerp_weight)

            lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)

            #print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")

            # append to batch
            batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
            batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)

        # should have shape (B, 77, 768)
        #print(f"assembled all tokens into tensor of shape {batch_z.shape}")

        if should_return_tokens:
            return batch_z, batch_tokens
        else:
            return batch_z
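Worked numbers for the tan() lerp-weight trick described in the comments above (an illustrative sketch, not part of this diff):

    import math

    for w in (1.0, 0.75, 0.5, 0.25):
        print(w, math.tan((1.0 - w) * math.pi / 2))
    # 1.0  -> 0.0    stay entirely at base_z
    # 0.75 -> ~0.41  mostly base_z
    # 0.5  -> 1.0    after normalization, exactly half-way to the fragment-removed embedding
    # 0.25 -> ~2.41  mostly overridden by the fragment-removed embedding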
    def get_token_ids(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
        """
        Convert a list of strings like `["a cat", "sitting", "on a mat"]` into a list of lists of token ids like
        `[[bos, 0, 1, eos], [bos, 2, eos], [bos, 3, 0, 4, eos]]`. bos/eos markers are skipped if
        `include_start_and_end_markers` is `False`. Each list will be restricted to the maximum permitted length
        (typically 75 tokens + eos/bos markers).

        :param fragments: The strings to convert.
        :param include_start_and_end_markers:
        :return:
        """
        # for args documentation see ENCODE_KWARGS_DOCSTRING in tokenization_utils_base.py (in `transformers` lib)
        token_ids_list = self.tokenizer(
            fragments,
            truncation=True,
            max_length=self.max_length,
            return_overflowing_tokens=False,
            padding='do_not_pad',
            return_tensors=None,  # just give me lists of ints
        )['input_ids']

        result = []
        for token_ids in token_ids_list:
            # trim eos/bos
            token_ids = token_ids[1:-1]
            # pad for textual inversions with vector length >1
            token_ids = self.textual_inversion_manager.expand_textual_inversion_token_ids_if_necessary(token_ids)
            # restrict length to max_length-2 (leaving room for bos/eos)
            token_ids = token_ids[0:self.max_length - 2]
            # add back eos/bos if requested
            if include_start_and_end_markers:
                token_ids = [self.tokenizer.bos_token_id] + token_ids + [self.tokenizer.eos_token_id]

            result.append(token_ids)

        return result
    @classmethod
    def apply_embedding_weights(cls, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize: bool) -> torch.Tensor:
        per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
        if normalize:
            per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
        reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
        #reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
        # lerped embeddings has shape (77, 768)
        return torch.sum(embeddings * reshaped_weights, dim=1)
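A toy trace of apply_embedding_weights() above, with deliberately tiny shapes ([1, n=2, seq=2, dim=3] standing in for the real [1, n, 77, 768]; illustrative only):

    import torch

    e = torch.stack([torch.full((2, 3), 2.0), torch.zeros(2, 3)]).unsqueeze(0)  # [1, 2, 2, 3]
    w = torch.tensor([1.0, 1.0]) / 2                                            # normalized per-embedding weights
    out = torch.sum(e * w.reshape(2, 1, 1), dim=1)                              # [1, 2, 3], every value 1.0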
    def get_token_ids_and_expand_weights(self, fragments: list[str], weights: list[float], device: str) -> (torch.Tensor, torch.Tensor):
        '''
        Given a list of text fragments and corresponding weights: tokenize each fragment, append the token sequences
        together and return a padded token sequence starting with the bos marker, ending with the eos marker, and padded
        or truncated as appropriate to `self.max_length`. Also return a list of weights expanded from the passed-in
        weights to match each token.

        :param fragments: Text fragments to tokenize and concatenate. May be empty.
        :param weights: Per-fragment weights (i.e. quasi-CFG scaling). Values from 0 to inf are permitted. In practice with SD1.5
            values >1.6 tend to produce garbage output. Must have the same length as `fragments`.
        :return: A tuple of tensors `(token_ids, weights)`. `token_ids` is ints, `weights` is floats, both have shape `[self.max_length]`.
        '''
        if len(fragments) != len(weights):
            raise ValueError(f"lengths of text and fragment_weights lists are not the same ({len(fragments)} != {len(weights)})")

        # empty is meaningful
        if len(fragments) == 0:
            fragments = ['']
            weights = [1.0]
        per_fragment_token_ids = self.get_token_ids(fragments, include_start_and_end_markers=False)
        all_token_ids = []
        per_token_weights = []
        #print("all fragments:", fragments, weights)
        for this_fragment_token_ids, weight in zip(per_fragment_token_ids, weights):
            # append
            all_token_ids += this_fragment_token_ids
            # fill out weights tensor with one float per token
            per_token_weights += [float(weight)] * len(this_fragment_token_ids)

        # leave room for bos/eos
        max_token_count_without_bos_eos_markers = self.max_length - 2
        if len(all_token_ids) > max_token_count_without_bos_eos_markers:
            excess_token_count = len(all_token_ids) - max_token_count_without_bos_eos_markers
            # TODO build nice description string of how the truncation was applied
            # this should be done by calling self.tokenizer.convert_ids_to_tokens() then passing the result to
            # self.tokenizer.convert_tokens_to_string() for the token_ids on each side of the truncation limit.
            print(f">> Prompt is {excess_token_count} token(s) too long and has been truncated")
            all_token_ids = all_token_ids[0:max_token_count_without_bos_eos_markers]
            per_token_weights = per_token_weights[0:max_token_count_without_bos_eos_markers]

        # pad out to a self.max_length-entry array: [bos_token, <prompt tokens>, eos_token, pad_token…]
        # (typically self.max_length == 77)
        all_token_ids = [self.tokenizer.bos_token_id] + all_token_ids + [self.tokenizer.eos_token_id]
        per_token_weights = [1.0] + per_token_weights + [1.0]
        pad_length = self.max_length - len(all_token_ids)
        all_token_ids += [self.tokenizer.pad_token_id] * pad_length
        per_token_weights += [1.0] * pad_length

        all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long, device=device)
        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch_dtype(self.text_encoder.device), device=device)
        #print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
        return all_token_ids_tensor, per_token_weights_tensor
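The resulting layout, sketched for a hypothetical max_length of 8 and two fragments weighted 0.8 and 1.2 (token ids t1..t3 are made up; illustrative only):

    # all_token_ids:     [bos, t1, t2, t3, eos, pad, pad, pad]
    # per_token_weights: [1.0, 0.8, 0.8, 1.2, 1.0, 1.0, 1.0, 1.0]
    # (t1, t2 come from the fragment weighted 0.8; t3 from the fragment weighted 1.2)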
    def build_weighted_embedding_tensor(self, token_ids: torch.Tensor, per_token_weights: torch.Tensor) -> torch.Tensor:
        '''
        Build a tensor that embeds the passed-in token IDs and applies the given per_token weights
        :param token_ids: A tensor of shape `[self.max_length]` containing token IDs (ints)
        :param per_token_weights: A tensor of shape `[self.max_length]` containing weights (floats)
        :return: A tensor of shape `[1, self.max_length, token_dim]` representing the requested weighted embeddings
            where `token_dim` is 768 for SD1 and 1280 for SD2.
        '''
        #print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
        if token_ids.shape != torch.Size([self.max_length]):
            raise ValueError(f"token_ids has shape {token_ids.shape} - expected [{self.max_length}]")

        z = self.text_encoder(token_ids.unsqueeze(0), return_dict=False)[0]
        empty_token_ids = torch.tensor([self.tokenizer.bos_token_id] +
                                       [self.tokenizer.pad_token_id] * (self.max_length - 2) +
                                       [self.tokenizer.eos_token_id], dtype=torch.int, device=z.device).unsqueeze(0)
        empty_z = self.text_encoder(empty_token_ids).last_hidden_state
        batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape).to(z)
        z_delta_from_empty = z - empty_z
        weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)

        return weighted_z
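In effect the last three lines above compute, per token position, a lerp from the empty-prompt embedding towards the full-prompt embedding; equivalently (a sketch, not a change to the diff):

    # weighted_z = torch.lerp(empty_z, z, batch_weights_expanded)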
@@ -8,6 +8,7 @@ import torch
from picklescan.scanner import scan_file_path
from transformers import CLIPTextModel, CLIPTokenizer

from compel.embeddings_provider import BaseTextualInversionManager
from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary


@@ -23,7 +24,7 @@ class TextualInversion:
        return self.embedding.shape[0]


class TextualInversionManager:
class TextualInversionManager(BaseTextualInversionManager):
    def __init__(
        self,
        tokenizer: CLIPTokenizer,
@@ -105,7 +106,7 @@ class TextualInversionManager:

    def _add_textual_inversion(
        self, trigger_str, embedding, defer_injecting_tokens=False
    ) -> TextualInversion:
    ) -> Optional[TextualInversion]:
        """
        Add a textual inversion to be recognised.
        :param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
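The BaseTextualInversionManager base class imported above is how compel calls back into InvokeAI's textual inversions; a minimal conforming subclass would only need the hook already used earlier in this diff (the method name is taken from the diff, everything else is an assumed sketch):

    from compel.embeddings_provider import BaseTextualInversionManager

    class NoopTextualInversionManager(BaseTextualInversionManager):
        # no multi-vector embeddings to expand: pass token ids through unchanged
        def expand_textual_inversion_token_ids_if_necessary(self, token_ids: list[int]) -> list[int]:
            return token_ids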
@@ -38,6 +38,7 @@ dependencies = [
    "albumentations",
    "click",
    "clip_anytorch",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
    "compel",
    "datasets",
    "diffusers[torch]~=0.11",
    "dnspython==2.2.1",
@@ -1,499 +0,0 @@
import unittest

import pyparsing

from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute, \
    Fragment


def parse_prompt(prompt_string):
    pp = PromptParser()
    #print(f"parsing '{prompt_string}'")
    parse_result = pp.parse_conjunction(prompt_string)
    #print(f"-> parsed '{prompt_string}' to {parse_result}")
    return parse_result

def make_basic_conjunction(strings: list[str]):
    fragments = [Fragment(x) for x in strings]
    return Conjunction([FlattenedPrompt(fragments)])

def make_weighted_conjunction(weighted_strings: list[tuple[str,float]]):
    fragments = [Fragment(x, w) for x,w in weighted_strings]
    return Conjunction([FlattenedPrompt(fragments)])


class PromptParserTestCase(unittest.TestCase):

    def test_empty(self):
        self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt(''))

    def test_basic(self):
        self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames"))
        self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)"))
        self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames"))
        self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire"))
        self.assertEqual(make_weighted_conjunction([("cat hot-dog eating", 1)]), parse_prompt("cat hot-dog eating"))
        self.assertEqual(make_basic_conjunction(['Dalí']), parse_prompt("Dalí"))

    def test_attention(self):
        self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5"))
        self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames).attend(0.5)"))
        self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("flames.attend(0.5)"))
        self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("\"flames\".attend(0.5)"))
        self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5"))
        self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames).attend(0.5)"))

        self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+"))
        self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+"))
        self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+"))
        self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames.attend(+)"))
        self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames).attend(+)"))
        self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\".attend(+)"))
        self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-"))
        self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-"))
        self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("\"flames\"-"))
        self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5"))
        self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire flames.attend(0.5)"))
        self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames).attend(0.5)"))
        self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire \"flames\".attend(0.5)"))
        self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++"))
        self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--"))
        self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++"))
        self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1)]),
                         parse_prompt("(pretty flowers)+"))
        self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1), (', the flames are too hot', 1)]),
                         parse_prompt("(pretty flowers)+, the flames are too hot"))

    def test_no_parens_attention_runon(self):
        self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(1.1, 2))]), parse_prompt("fire flames++"))
        self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(0.9, 2))]), parse_prompt("fire flames--"))
        self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("flowers fire++ flames"))
        self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("flowers fire-- flames"))


    def test_explicit_conjunction(self):
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and(1,1)'))
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()'))
        self.assertEqual(
            Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()'))
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("(fire)2.0", "flames-").and()'))
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]),
                                      FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()'))

    def test_conjunction_weights(self):
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[2.0,1.0]), parse_prompt('("fire", "flames").and(2,1)'))
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[1.0,2.0]), parse_prompt('("fire", "flames").and(1,2)'))

        with self.assertRaises(PromptParser.ParsingException):
            parse_prompt('("fire", "flames").and(2)')
            parse_prompt('("fire", "flames").and(2,1,2)')

    def test_complex_conjunction(self):

        #print(parse_prompt("a person with a hat (riding a bicycle.swap(skateboard))++"))

        self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]),
                         parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle)++\").and(0.5, 0.5)"))
        self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]),
                                      FlattenedPrompt([("a person with a hat", 1.0),
                                                       ("riding a", 1.1*1.1),
                                                       CrossAttentionControlSubstitute(
                                                           [Fragment("bicycle", pow(1.1,2))],
                                                           [Fragment("skateboard", pow(1.1,2))])
                                                       ])
                                      ], weights=[0.5, 0.5]),
                         parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)"))

    def test_badly_formed(self):
        def make_untouched_prompt(prompt):
            return Conjunction([FlattenedPrompt([(prompt, 1.0)])])

        def assert_if_prompt_string_not_untouched(prompt):
            self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt))

        assert_if_prompt_string_not_untouched('a test prompt')
        assert_if_prompt_string_not_untouched('a badly formed +test prompt')
        assert_if_prompt_string_not_untouched('a badly (formed test prompt')

        #with self.assertRaises(pyparsing.ParseException):
        assert_if_prompt_string_not_untouched('a badly (formed +test prompt')
        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a badly formed +test prompt',1)])]), parse_prompt('a badly (formed +test )prompt'))
        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(((a badly formed +test prompt',1)])]), parse_prompt('(((a badly (formed +test )prompt'))

        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(a ba dly f ormed +test prompt',1)])]), parse_prompt('(a (ba)dly (f)ormed +test prompt'))
        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('(a ba dly f ormed +test +prompt',1)])]), parse_prompt('(a (ba)dly (f)ormed +test +prompt'))
        self.assertEqual(Conjunction([Blend([FlattenedPrompt([Fragment('((a badly (formed +test', 1)])], [1.0])]),
                         parse_prompt('("((a badly (formed +test ").blend(1.0)'))

        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
                         parse_prompt("hamburger ((bun))"))
        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger bun', 1)])]),
                         parse_prompt("hamburger (bun)"))
        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger kaiser roll', 1)])]),
                         parse_prompt("hamburger (kaiser roll)"))
        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('hamburger kaiser roll', 1)])]),
                         parse_prompt("hamburger ((kaiser roll))"))


    def test_blend(self):
        self.assertEqual(Conjunction(
            [Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
            parse_prompt("(\"mountain\", \"man\").blend()")
        )
        self.assertEqual(Conjunction(
            [Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
            parse_prompt("(mountain, man).blend()")
        )
        self.assertEqual(Conjunction(
            [Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('man', 1.0)])], [1.0, 1.0])]),
            parse_prompt("((mountain), (man)).blend()")
        )
        self.assertEqual(Conjunction(
            [Blend([FlattenedPrompt([('mountain', 1.0)]), FlattenedPrompt([('tall man', 1.0)])], [1.0, 1.0])]),
            parse_prompt("((mountain), (tall man)).blend()")
        )

        with self.assertRaises(PromptParser.ParsingException):
            print(parse_prompt("((mountain), \"cat.swap(dog)\").blend()"))

        self.assertEqual(Conjunction(
            [Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]),
            parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)")
        )
        self.assertEqual(Conjunction([Blend(
            [FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('hi', 1.0)])],
            [0.7, 0.3, 1.0])]),
            parse_prompt("(\"fire\", \"fire flames\", \"hi\").blend(0.7, 0.3, 1.0)")
        )
        self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]),
                                             FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]),
                                             FlattenedPrompt([('hi', 1.0)])],
                                            weights=[0.7, 0.3, 1.0])]),
                         parse_prompt("(\"fire\", \"fire flames (hot)++\", \"hi\").blend(0.7, 0.3, 1.0)")
                         )
        # blend a single entry is not a failure
        self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]),
                         parse_prompt("(\"fire\").blend(0.7)")
                         )
        # blend with empty
        self.assertEqual(
            Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
            parse_prompt("(\"fire\", \"\").blend(0.7, 1)")
        )
        self.assertEqual(
            Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
            parse_prompt("(\"fire\", \" \").blend(0.7, 1)")
        )
        self.assertEqual(
            Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
            parse_prompt("(\"fire\", \" \").blend(0.7, 1)")
        )
        self.assertEqual(
            Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([(',', 1.0)])], [0.7, 1.0])]),
            parse_prompt("(\"fire\", \" , \").blend(0.7, 1)")
        )

        self.assertEqual(
            Conjunction([Blend([FlattenedPrompt([('mountain , man , hairy', 1)]),
                                FlattenedPrompt([('face , teeth ,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0], normalize_weights=True)]),
            parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)')
        )
        self.assertEqual(
            Conjunction([Blend([FlattenedPrompt([('mountain , man , hairy', 1)]),
                                FlattenedPrompt([('face , teeth ,', 1), ('eyes', 0.9 * 0.9)])], weights=[1.0, -1.0], normalize_weights=False)]),
            parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1,no_normalize)')
        )

        with self.assertRaises(PromptParser.ParsingException):
            parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3, 0.1)")
        with self.assertRaises(PromptParser.ParsingException):
            parse_prompt("(\"fire\", \"fire flames\").blend(0.7)")


    def test_nested(self):
        self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)]),
                         parse_prompt('fire (flames (trees)1.5)2.0'))
        self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]),
                                                     FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])],
                                            weights=[1.0, 1.0])]),
                         parse_prompt('("fire (flames)++", "mountain (man)2").blend(1,1)'))

    def test_cross_attention_control(self):

        self.assertEqual(Conjunction([FlattenedPrompt([CrossAttentionControlSubstitute([Fragment('sun')], [Fragment('moon')])])]),
                         parse_prompt("sun.swap(moon)"))

        self.assertEqual(Conjunction([
            FlattenedPrompt([Fragment('a', 1),
                             CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
                             Fragment('eating a hotdog', 1)])]), parse_prompt("a \"cat\".swap(dog) eating a hotdog"))

        self.assertEqual(Conjunction([
            FlattenedPrompt([Fragment('a', 1),
                             CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
                             Fragment('eating a hotdog', 1)])]), parse_prompt("a cat.swap(dog) eating a hotdog"))


        fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \
            CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])])
        self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))
        self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)'))
        self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)'))
        self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap("trees")'))
        self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap("trees")'))
        self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")'))

        fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \
            CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees and houses', 1)])])])
        self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")'))
        self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")'))
        self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")'))

        trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \
            CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])])
        self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")'))
        self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")'))
        self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")'))
        self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap(flames)'))
        self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap(flames)'))
        self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)'))

        flames_to_trees_fire = Conjunction([FlattenedPrompt([
            CrossAttentionControlSubstitute([Fragment('flames',1)], [Fragment('trees',1)]),
            (', fire', 1.0)])])
        self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire'))
        self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire'))
        self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire'))
        self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
        self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
        self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))


        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
                                                       CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
                         parse_prompt('a forest landscape "".swap("in winter")'))
        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
                                                       CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
                         parse_prompt('a forest landscape ().swap(in winter)'))
        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
                                                       CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
                         parse_prompt('a forest landscape " ".swap("in winter")'))

        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
                                                       CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
                         parse_prompt('a forest landscape "in winter".swap("")'))
        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
                                                       CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
                         parse_prompt('a forest landscape "in winter".swap()'))
        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
                                                       CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
                         parse_prompt('a forest landscape "in winter".swap(" ")'))

    def test_cross_attention_control_with_attention(self):
        flames_to_trees_fire = Conjunction([FlattenedPrompt([
            CrossAttentionControlSubstitute([Fragment('flames',0.5)], [Fragment('trees',0.7)]),
            Fragment(',', 1), Fragment('fire', 2.0)])])
        self.assertEqual(flames_to_trees_fire, parse_prompt('"(flames)0.5".swap("(trees)0.7"), (fire)2.0'))
        flames_to_trees_fire = Conjunction([FlattenedPrompt([
            CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7)]),
            Fragment(',', 1), Fragment('fire', 2.0)])])
        self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7"), (fire)2.0'))
        flames_to_trees_fire = Conjunction([FlattenedPrompt([
            CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7), Fragment('houses', 1)]),
            Fragment(',', 1), Fragment('fire', 2.0)])])
        self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0'))

        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
                                                       CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
                                                       Fragment('eating a', 1),
                                                       CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))])
                                                       ])]),
                         parse_prompt("a cat.swap(dog) eating a hotdog.swap(hotdog++++)"))
        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
                                                       CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
                                                       Fragment('eating a', 1),
                                                       CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))])
                                                       ])]),
                         parse_prompt("a cat.swap(dog) eating a hotdog.swap(hotdog++++, shape_freedom=0.5)"))

        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
                                                       CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
                                                       Fragment('eating a', 1),
                                                       CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('hotdog', pow(1.1,4))])
                                                       ])]),
                         parse_prompt("a cat.swap(dog) eating a hotdog.swap(\"hotdog++++\", shape_freedom=0.5)"))

        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
                                                       CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
                                                       Fragment('eating a', 1),
                                                       CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(1.1,4))])
                                                       ])]),
                         parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog++++, shape_freedom=0.5)"))
        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
                                                       CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
                                                       Fragment('eating a', 1),
                                                       CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(1.1,4))])
                                                       ])]),
                         parse_prompt("a cat.swap(dog) eating a hotdog.swap(\"h\(o\)tdog++++\", shape_freedom=0.5)"))

        self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a', 1),
                                                       CrossAttentionControlSubstitute([Fragment('cat',1)], [Fragment('dog',1)]),
                                                       Fragment('eating a', 1),
                                                       CrossAttentionControlSubstitute([Fragment('hotdog',1)], [Fragment('h(o)tdog', pow(0.9,1))])
                                                       ])]),
                         parse_prompt("a cat.swap(dog) eating a hotdog.swap(h\(o\)tdog-, shape_freedom=0.5)"))


    def test_cross_attention_control_options(self):
        self.assertEqual(Conjunction([
            FlattenedPrompt([Fragment('a', 1),
                             CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start':0.1}),
                             Fragment('eating a hotdog', 1)])]),
            parse_prompt("a \"cat\".swap(dog, s_start=0.1) eating a hotdog"))
        self.assertEqual(Conjunction([
            FlattenedPrompt([Fragment('a', 1),
                             CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'t_start':0.1}),
                             Fragment('eating a hotdog', 1)])]),
            parse_prompt("a \"cat\".swap(dog, t_start=0.1) eating a hotdog"))
        self.assertEqual(Conjunction([
            FlattenedPrompt([Fragment('a', 1),
                             CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start': 20.0, 't_start':0.1}),
                             Fragment('eating a hotdog', 1)])]),
            parse_prompt("a \"cat\".swap(dog, t_start=0.1, s_start=20) eating a hotdog"))

        self.assertEqual(
            Conjunction([
                FlattenedPrompt([Fragment('a fantasy forest landscape', 1),
                                 CrossAttentionControlSubstitute([Fragment('', 1)], [Fragment('with a river', 1)],
                                                                 options={'s_start': 0.8, 't_start': 0.8})])]),
            parse_prompt("a fantasy forest landscape \"\".swap(with a river, s_start=0.8, t_start=0.8)"))


    def test_escaping(self):

        # make sure ", ( and ) can be escaped

        self.assertEqual(make_basic_conjunction(['mountain (man)']), parse_prompt('mountain \(man\)'))
        self.assertEqual(make_basic_conjunction(['mountain (man )']), parse_prompt('mountain (\(man)\)'))
        self.assertEqual(make_basic_conjunction(['mountain (man)']), parse_prompt('mountain (\(man\))'))
        self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('mountain (\(man\))+'))
        self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('"mountain" (\(man\))+'))
        self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('(man)', 1.1)]), parse_prompt('\\"mountain\\" (\(man\))+'))
        # same weights for each are combined into one
        self.assertEqual(make_weighted_conjunction([('"mountain" (man)', 1.1)]), parse_prompt('(\\"mountain\\")+ (\(man\))+'))
        self.assertEqual(make_weighted_conjunction([('"mountain"', 1.1), ('(man)', 0.9)]), parse_prompt('(\\"mountain\\")+ (\(man\))-'))

        self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]), parse_prompt('mountain (\(man\))1.1'))
        self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]), parse_prompt('"mountain" (\(man\))1.1'))
        self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('\(man\)', 1.1)]), parse_prompt('\\"mountain\\" (\(man\))1.1'))
        # same weights for each are combined into one
        self.assertEqual(make_weighted_conjunction([('\\"mountain\\" \(man\)', 1.1)]), parse_prompt('(\\"mountain\\")+ (\(man\))1.1'))
        self.assertEqual(make_weighted_conjunction([('\\"mountain\\"', 1.1), ('\(man\)', 0.9)]), parse_prompt('(\\"mountain\\")1.1 (\(man\))0.9'))

        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]), parse_prompt('hairy (mountain (\(man\))+)+'))
        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]), parse_prompt('hairy ((\(man\))1.1 "mountain")+'))
        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]), parse_prompt('hairy ("mountain" (\(man\))1.1 )+'))
        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , man', 1.1)]), parse_prompt('hairy ("mountain, man")+'))
        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy ("mountain, man" with a beard+)+'))
        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, man" with a (beard)2.0)+'))
        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\"man\\"" with a (beard)2.0)+'))
        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, m\\"an\\"" with a (beard)2.0)+'))

        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \(with a (beard)2.0)+'))
        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\(ith a (beard)2.0)+'))
        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\( a (beard)2.0)+'))
        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \)with a (beard)2.0)+'))
        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain , \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\) a (beard)2.0)+'))
        self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain , \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+'))
        self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
        self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain , \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( ("mountain, \\\"man\" with a (beard)2.0)+'))

        self.assertEqual(make_weighted_conjunction([('mountain , \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \(with a (beard)2.0)+ hairy'))
        self.assertEqual(make_weighted_conjunction([('mountain , \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\(ith a (beard)2.0)+hairy'))
        self.assertEqual(make_weighted_conjunction([('mountain , \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" with\( a (beard)2.0)+ hairy'))
        self.assertEqual(make_weighted_conjunction([('mountain , \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \)with a (beard)2.0)+ hairy'))
        self.assertEqual(make_weighted_conjunction([('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hairy'))
        self.assertEqual(make_weighted_conjunction([('mountain , \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' ("mountain, \\\"man\" with\) a (beard)2.0)+ hairy'))
        self.assertEqual(make_weighted_conjunction([('mou)ntain , \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+ hairy'))
        self.assertEqual(make_weighted_conjunction([('mountain , \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hai\(ry '))
        self.assertEqual(make_weighted_conjunction([('mountain , \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('("mountain, \\\"man\" with a (beard)2.0)+ hairy\(\( '))

    def test_cross_attention_escaping(self):

        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
                         parse_prompt('mountain (man).swap(monkey)'))
        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]),
                         parse_prompt('mountain (man).swap(m\(onkey)'))
        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]),
                         parse_prompt('mountain (m\(an).swap(m\(onkey)'))
        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]),
                         parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))

        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
                         parse_prompt('mountain ("man").swap(monkey)'))
        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
                         parse_prompt('mountain ("man").swap("monkey")'))
        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('"man', 1)], [Fragment('monkey', 1)])])]),
                         parse_prompt('mountain (\\"man).swap("monkey")'))
        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]),
                         parse_prompt('mountain (man).swap(m\(onkey)'))
        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]),
                         parse_prompt('mountain (m\(an).swap(m\(onkey)'))
        self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]),
                         parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))

    def test_legacy_blend(self):
        pp = PromptParser()

        self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
                                FlattenedPrompt([('man mountain', 1)])],
                               weights=[0.5,0.5]),
                         pp.parse_legacy_blend('mountain man:1 man mountain:1'))

        self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
                                FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
                               weights=[0.5,0.5]),
                         pp.parse_legacy_blend('mountain+ man:1 man mountain-:1'))

        self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
                                FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
                               weights=[0.5,0.5]),
                         pp.parse_legacy_blend('mountain+ man:1 man mountain-'))

        self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
                                FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
                               weights=[0.5,0.5]),
                         pp.parse_legacy_blend('mountain+ man: man mountain-:'))

        self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
                                FlattenedPrompt([('man mountain', 1)])],
                               weights=[0.75,0.25]),
                         pp.parse_legacy_blend('mountain man:3 man mountain:1'))

        self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
                                FlattenedPrompt([('man mountain', 1)])],
                               weights=[1.0,0.0]),
                         pp.parse_legacy_blend('mountain man:3 man mountain:0'))

        self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
                                FlattenedPrompt([('man mountain', 1)])],
                               weights=[0.8,0.2]),
                         pp.parse_legacy_blend('"mountain man":4 man mountain'))


    def test_single(self):
        self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]),
                                      FlattenedPrompt([("a person with a hat", 1.0),
                                                       ("riding a", 1.1*1.1),
                                                       CrossAttentionControlSubstitute(
                                                           [Fragment("bicycle", pow(1.1,2))],
                                                           [Fragment("skateboard", pow(1.1,2))])
                                                       ])
                                      ], weights=[0.5, 0.5]),
                         parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)"))
        pass


if __name__ == '__main__':
    unittest.main()