mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'fix-prompts' of https://github.com/damian0815/InvokeAI into merge-prompt-and-inpaint-model
This commit is contained in:
commit
2f1c1e7695
@ -14,7 +14,7 @@ from threading import Event
|
||||
|
||||
from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
|
||||
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
|
||||
from ldm.invoke.conditioning import split_weighted_subprompts
|
||||
from ldm.invoke.prompt_parser import split_weighted_subprompts
|
||||
|
||||
from backend.modules.parameters import parameters_to_command
|
||||
|
||||
|
@ -33,7 +33,7 @@ from ldm.generate import Generate
|
||||
from ldm.invoke.restoration import Restoration
|
||||
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
|
||||
from ldm.invoke.args import APP_ID, APP_VERSION, calculate_init_img_hash
|
||||
from ldm.invoke.conditioning import split_weighted_subprompts
|
||||
from ldm.invoke.prompt_parser import split_weighted_subprompts
|
||||
|
||||
from modules.parameters import parameters_to_command
|
||||
|
||||
|
@ -76,4 +76,4 @@ model:
|
||||
target: torch.nn.Identity
|
||||
|
||||
cond_stage_config:
|
||||
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||
target: ldm.modules.encoders.modules.WeightedFrozenCLIPEmbedder
|
||||
|
@ -84,6 +84,48 @@ Getting close - but there's no sense in having a saddle when our horse doesn't h
|
||||
|
||||
---
|
||||
|
||||
## **Prompt Syntax Features**
|
||||
|
||||
The InvokeAI prompting language has the following features:
|
||||
|
||||
### Attention weighting
|
||||
Append a word or phrase with `-` or `+`, or a weight between `0` and `2` (`1`=default), to decrease or increase "attention" (= a mix of per-token CFG weighting multiplier and, for `-`, a weighted blend with the prompt without the term).
|
||||
|
||||
The following will be recognised:
|
||||
* single words without parentheses: `a tall thin man picking apricots+`
|
||||
* single or multiple words with parentheses: `a tall thin man picking (apricots)+` `a tall thin man picking (apricots)-` `a tall thin man (picking apricots)+` `a tall thin man (picking apricots)-`
|
||||
* more effect with more symbols `a tall thin man (picking apricots)++`
|
||||
* nesting `a tall thin man (picking apricots+)++` (`apricots` effectively gets `+++`)
|
||||
* all of the above with explicit numbers `a tall thin man picking (apricots)1.1` `a tall thin man (picking (apricots)1.3)1.1`. (`+` is equivalent to 1.1, `++` is pow(1.1,2), `+++` is pow(1.1,3), etc; `-` means 0.9, `--` means pow(0.9,2), etc.)
|
||||
* attention also applies to `[unconditioning]` so `a tall thin man picking apricots [(ladder)0.01]` will *very gently* nudge SD away from trying to draw the man on a ladder
|
||||
|
||||
### Blending between prompts
|
||||
|
||||
* `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,1)`
|
||||
* The existing prompt blending using `:<weight>` will continue to be supported - `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,1)` is equivalent to `a tall thin man picking apricots:1 a tall thin man picking pears:1` in the old syntax.
|
||||
* Attention weights can be nested inside blends.
|
||||
* Non-normalized blends are supported by passing `no_normalize` as an additional argument to the blend weights, eg `("a tall thin man picking apricots", "a tall thin man picking pears").blend(1,-1,no_normalize)`. very fun to explore local maxima in the feature space, but also easy to produce garbage output.
|
||||
|
||||
See the section below on "Prompt Blending" for more information about how this works.
|
||||
|
||||
### Cross-Attention Control ('prompt2prompt')
|
||||
|
||||
Denoise with a given prompt and then re-use the attention→pixel maps to substitute words in the original prompt for words in a new prompt. Based off [bloc97's colab](https://github.com/bloc97/CrossAttentionControl).
|
||||
|
||||
* `a ("fluffy cat").swap("smiling dog") eating a hotdog`.
|
||||
* quotes optional: `a (fluffy cat).swap(smiling dog) eating a hotdog`.
|
||||
* for single word substitutions parentheses are also optional: `a cat.swap(dog) eating a hotdog`.
|
||||
* Supports options `s_start`, `s_end`, `t_start`, `t_end` (each 0-1) loosely corresponding to bloc97's `prompt_edit_spatial_start/_end` and `prompt_edit_tokens_start/_end` but with the math swapped to make it easier to intuitively understand.
|
||||
* Example usage:`a (cat).swap(dog, s_end=0.3) eating a hotdog` - the `s_end` argument means that the "spatial" (self-attention) edit will stop having any effect after 30% (=0.3) of the steps have been done, leaving Stable Diffusion with 70% of the steps where it is free to decide for itself how to reshape the cat-form into a dog form.
|
||||
* The numbers represent a percentage through the step sequence where the edits should happen. 0 means the start (noisy starting image), 1 is the end (final image).
|
||||
* For img2img, the step sequence does not start at 0 but instead at (1-strength) - so if strength is 0.7, s_start and s_end must both be greater than 0.3 (1-0.7) to have any effect.
|
||||
* Convenience option `shape_freedom` (0-1) to specify how much "freedom" Stable Diffusion should have to change the shape of the subject being swapped.
|
||||
* `a (cat).swap(dog, shape_freedom=0.5) eating a hotdog`.
|
||||
|
||||
### Escaping parantheses () and speech marks ""
|
||||
|
||||
If the model you are using has parentheses () or speech marks "" as part of its syntax, you will need to "escape" these using a backslash, so that`(my_keyword)` becomes `\(my_keyword\)`. Otherwise, the prompt parser will attempt to interpret the parentheses as part of the prompt syntax and it will get confused.
|
||||
|
||||
## **Prompt Blending**
|
||||
|
||||
You may blend together different sections of the prompt to explore the
|
||||
|
@ -1,5 +1,5 @@
|
||||
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
||||
|
||||
import pyparsing
|
||||
# Derived from source code carrying the following copyrights
|
||||
# Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich
|
||||
# Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors
|
||||
@ -24,6 +24,7 @@ from PIL import Image, ImageOps
|
||||
from torch import nn
|
||||
from pytorch_lightning import seed_everything, logging
|
||||
|
||||
from ldm.invoke.prompt_parser import PromptParser
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
@ -32,7 +33,7 @@ from ldm.invoke.pngwriter import PngWriter
|
||||
from ldm.invoke.args import metadata_from_png
|
||||
from ldm.invoke.image_util import InitImageResizer
|
||||
from ldm.invoke.devices import choose_torch_device, choose_precision
|
||||
from ldm.invoke.conditioning import get_uc_and_c
|
||||
from ldm.invoke.conditioning import get_uc_and_c_and_ec
|
||||
from ldm.invoke.model_cache import ModelCache
|
||||
from ldm.invoke.seamless import configure_model_padding
|
||||
from ldm.invoke.txt2mask import Txt2Mask, SegmentedGrayscale
|
||||
@ -404,7 +405,7 @@ class Generate:
|
||||
mask_image = None
|
||||
|
||||
try:
|
||||
uc, c = get_uc_and_c(
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(
|
||||
prompt, model =self.model,
|
||||
skip_normalize=skip_normalize,
|
||||
log_tokens =self.log_tokenization
|
||||
@ -448,7 +449,7 @@ class Generate:
|
||||
sampler=self.sampler,
|
||||
steps=steps,
|
||||
cfg_scale=cfg_scale,
|
||||
conditioning=(uc, c),
|
||||
conditioning=(uc, c, extra_conditioning_info),
|
||||
ddim_eta=ddim_eta,
|
||||
image_callback=image_callback, # called after the final image is generated
|
||||
step_callback=step_callback, # called after each intermediate image is generated
|
||||
@ -481,14 +482,14 @@ class Generate:
|
||||
save_original = save_original,
|
||||
image_callback = image_callback)
|
||||
|
||||
except RuntimeError as e:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print('>> Could not generate image.')
|
||||
except KeyboardInterrupt:
|
||||
if catch_interrupts:
|
||||
print('**Interrupted** Partial results will be returned.')
|
||||
else:
|
||||
raise KeyboardInterrupt
|
||||
except RuntimeError as e:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print('>> Could not generate image.')
|
||||
|
||||
toc = time.time()
|
||||
print('>> Usage stats:')
|
||||
@ -553,7 +554,8 @@ class Generate:
|
||||
image = Image.open(image_path)
|
||||
|
||||
# used by multiple postfixers
|
||||
uc, c = get_uc_and_c(
|
||||
# todo: cross-attention control
|
||||
uc, c, _ = get_uc_and_c_and_ec(
|
||||
prompt, model =self.model,
|
||||
skip_normalize=opt.skip_normalize,
|
||||
log_tokens =opt.log_tokenization
|
||||
|
@ -92,7 +92,7 @@ import copy
|
||||
import base64
|
||||
import functools
|
||||
import ldm.invoke.pngwriter
|
||||
from ldm.invoke.conditioning import split_weighted_subprompts
|
||||
from ldm.invoke.prompt_parser import split_weighted_subprompts
|
||||
|
||||
SAMPLER_CHOICES = [
|
||||
'ddim',
|
||||
|
@ -4,107 +4,166 @@ weighted subprompts.
|
||||
|
||||
Useful function exports:
|
||||
|
||||
get_uc_and_c() get the conditioned and unconditioned latent
|
||||
get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
|
||||
split_weighted_subpromopts() split subprompts, normalize and weight them
|
||||
log_tokenization() print out colour-coded tokens and warn if truncated
|
||||
|
||||
'''
|
||||
import re
|
||||
from difflib import SequenceMatcher
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
|
||||
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
|
||||
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment
|
||||
from ..models.diffusion.cross_attention_control import CrossAttentionControl
|
||||
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
|
||||
|
||||
|
||||
def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):
|
||||
|
||||
# Extract Unconditioned Words From Prompt
|
||||
unconditioned_words = ''
|
||||
unconditional_regex = r'\[(.*?)\]'
|
||||
unconditionals = re.findall(unconditional_regex, prompt)
|
||||
unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)
|
||||
|
||||
if len(unconditionals) > 0:
|
||||
unconditioned_words = ' '.join(unconditionals)
|
||||
|
||||
# Remove Unconditioned Words From Prompt
|
||||
unconditional_regex_compile = re.compile(unconditional_regex)
|
||||
clean_prompt = unconditional_regex_compile.sub(' ', prompt)
|
||||
prompt = re.sub(' +', ' ', clean_prompt)
|
||||
clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned)
|
||||
prompt_string_cleaned = re.sub(' +', ' ', clean_prompt)
|
||||
else:
|
||||
prompt_string_cleaned = prompt_string_uncleaned
|
||||
|
||||
uc = model.get_learned_conditioning([unconditioned_words])
|
||||
pp = PromptParser()
|
||||
|
||||
# get weighted sub-prompts
|
||||
weighted_subprompts = split_weighted_subprompts(
|
||||
prompt, skip_normalize
|
||||
parsed_prompt: Union[FlattenedPrompt, Blend] = None
|
||||
legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned)
|
||||
if legacy_blend is not None:
|
||||
parsed_prompt = legacy_blend
|
||||
else:
|
||||
# we don't support conjunctions for now
|
||||
parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
|
||||
|
||||
parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
|
||||
print(f">> Parsed prompt to {parsed_prompt}")
|
||||
|
||||
conditioning = None
|
||||
cac_args:CrossAttentionControl.Arguments = None
|
||||
|
||||
if type(parsed_prompt) is Blend:
|
||||
blend: Blend = parsed_prompt
|
||||
embeddings_to_blend = None
|
||||
for flattened_prompt in blend.prompts:
|
||||
this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
|
||||
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
|
||||
(embeddings_to_blend, this_embedding))
|
||||
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
|
||||
blend.weights,
|
||||
normalize=blend.normalize_weights)
|
||||
else:
|
||||
flattened_prompt: FlattenedPrompt = parsed_prompt
|
||||
wants_cross_attention_control = type(flattened_prompt) is not Blend \
|
||||
and any([issubclass(type(x), CrossAttentionControlledFragment) for x in flattened_prompt.children])
|
||||
if wants_cross_attention_control:
|
||||
original_prompt = FlattenedPrompt()
|
||||
edited_prompt = FlattenedPrompt()
|
||||
# for name, a0, a1, b0, b1 in edit_opcodes: only name == 'equal' is currently parsed
|
||||
original_token_count = 0
|
||||
edited_token_count = 0
|
||||
edit_opcodes = []
|
||||
edit_options = []
|
||||
for fragment in flattened_prompt.children:
|
||||
if type(fragment) is CrossAttentionControlSubstitute:
|
||||
original_prompt.append(fragment.original)
|
||||
edited_prompt.append(fragment.edited)
|
||||
|
||||
to_replace_token_count = get_tokens_length(model, fragment.original)
|
||||
replacement_token_count = get_tokens_length(model, fragment.edited)
|
||||
edit_opcodes.append(('replace',
|
||||
original_token_count, original_token_count + to_replace_token_count,
|
||||
edited_token_count, edited_token_count + replacement_token_count
|
||||
))
|
||||
original_token_count += to_replace_token_count
|
||||
edited_token_count += replacement_token_count
|
||||
edit_options.append(fragment.options)
|
||||
#elif type(fragment) is CrossAttentionControlAppend:
|
||||
# edited_prompt.append(fragment.fragment)
|
||||
else:
|
||||
# regular fragment
|
||||
original_prompt.append(fragment)
|
||||
edited_prompt.append(fragment)
|
||||
|
||||
count = get_tokens_length(model, [fragment])
|
||||
edit_opcodes.append(('equal', original_token_count, original_token_count+count, edited_token_count, edited_token_count+count))
|
||||
edit_options.append(None)
|
||||
original_token_count += count
|
||||
edited_token_count += count
|
||||
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt, log_tokens=log_tokens)
|
||||
# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
|
||||
# subsequent tokens when there is >1 edit and earlier edits change the total token count.
|
||||
# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
|
||||
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
|
||||
# token 'smiling' in the inactive 'cat' edit.
|
||||
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
|
||||
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt, log_tokens=log_tokens)
|
||||
|
||||
conditioning = original_embeddings
|
||||
edited_conditioning = edited_embeddings
|
||||
print('got edit_opcodes', edit_opcodes, 'options', edit_options)
|
||||
cac_args = CrossAttentionControl.Arguments(
|
||||
edited_conditioning = edited_conditioning,
|
||||
edit_opcodes = edit_opcodes,
|
||||
edit_options = edit_options
|
||||
)
|
||||
else:
|
||||
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
|
||||
|
||||
|
||||
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
|
||||
return (
|
||||
unconditioning, conditioning, InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
cross_attention_control_args=cac_args
|
||||
)
|
||||
)
|
||||
|
||||
if len(weighted_subprompts) > 1:
|
||||
# i dont know if this is correct.. but it works
|
||||
c = torch.zeros_like(uc)
|
||||
# normalize each "sub prompt" and add it
|
||||
for subprompt, weight in weighted_subprompts:
|
||||
log_tokenization(subprompt, model, log_tokens, weight)
|
||||
c = torch.add(
|
||||
c,
|
||||
model.get_learned_conditioning([subprompt]),
|
||||
alpha=weight,
|
||||
)
|
||||
else: # just standard 1 prompt
|
||||
log_tokenization(prompt, model, log_tokens, 1)
|
||||
c = model.get_learned_conditioning([prompt])
|
||||
uc = model.get_learned_conditioning([unconditioned_words])
|
||||
return (uc, c)
|
||||
|
||||
def split_weighted_subprompts(text, skip_normalize=False)->list:
|
||||
"""
|
||||
grabs all text up to the first occurrence of ':'
|
||||
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
||||
if ':' has no value defined, defaults to 1.0
|
||||
repeats until no text remaining
|
||||
"""
|
||||
prompt_parser = re.compile("""
|
||||
(?P<prompt> # capture group for 'prompt'
|
||||
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
|
||||
) # end 'prompt'
|
||||
(?: # non-capture group
|
||||
:+ # match one or more ':' characters
|
||||
(?P<weight> # capture group for 'weight'
|
||||
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
|
||||
)? # end weight capture group, make optional
|
||||
\s* # strip spaces after weight
|
||||
| # OR
|
||||
$ # else, if no ':' then match end of line
|
||||
) # end non-capture group
|
||||
""", re.VERBOSE)
|
||||
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
|
||||
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
|
||||
if skip_normalize:
|
||||
return parsed_prompts
|
||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||
if weight_sum == 0:
|
||||
print(
|
||||
"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
|
||||
equal_weight = 1 / max(len(parsed_prompts), 1)
|
||||
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
|
||||
def build_token_edit_opcodes(original_tokens, edited_tokens):
|
||||
original_tokens = original_tokens.cpu().numpy()[0]
|
||||
edited_tokens = edited_tokens.cpu().numpy()[0]
|
||||
|
||||
return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
|
||||
|
||||
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False):
|
||||
if type(flattened_prompt) is not FlattenedPrompt:
|
||||
raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
|
||||
fragments = [x.text for x in flattened_prompt.children]
|
||||
weights = [x.weight for x in flattened_prompt.children]
|
||||
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
|
||||
if not flattened_prompt.is_empty and log_tokens:
|
||||
start_token = model.cond_stage_model.tokenizer.bos_token_id
|
||||
end_token = model.cond_stage_model.tokenizer.eos_token_id
|
||||
tokens_list = tokens[0].tolist()
|
||||
if tokens_list[0] == start_token:
|
||||
tokens_list[0] = '<start>'
|
||||
try:
|
||||
first_end_token_index = tokens_list.index(end_token)
|
||||
tokens_list[first_end_token_index] = '<end>'
|
||||
tokens_list = tokens_list[:first_end_token_index+1]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
print(f">> Prompt fragments {fragments}, tokenized to \n{tokens_list}")
|
||||
|
||||
return embeddings, tokens
|
||||
|
||||
def get_tokens_length(model, fragments: list[Fragment]):
|
||||
fragment_texts = [x.text for x in fragments]
|
||||
tokens = model.cond_stage_model.get_tokens(fragment_texts, include_start_and_end_markers=False)
|
||||
return sum([len(x) for x in tokens])
|
||||
|
||||
|
||||
# shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
def log_tokenization(text, model, log=False, weight=1):
|
||||
if not log:
|
||||
return
|
||||
tokens = model.cond_stage_model.tokenizer._tokenize(text)
|
||||
tokenized = ""
|
||||
discarded = ""
|
||||
usedTokens = 0
|
||||
totalTokens = len(tokens)
|
||||
for i in range(0, totalTokens):
|
||||
token = tokens[i].replace('</w>', ' ')
|
||||
# alternate color
|
||||
s = (usedTokens % 6) + 1
|
||||
if i < model.cond_stage_model.max_length:
|
||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
||||
usedTokens += 1
|
||||
else: # over max token length
|
||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||
print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
|
||||
if discarded != "":
|
||||
print(
|
||||
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
|
||||
)
|
||||
|
@ -10,6 +10,7 @@ from PIL import Image
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
class Img2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
@ -38,7 +39,7 @@ class Img2Img(Generator):
|
||||
) # move to latent space
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
uc, c = conditioning
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
|
||||
def make_image(x_T):
|
||||
# encode (scaled latent)
|
||||
@ -56,6 +57,8 @@ class Img2Img(Generator):
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
init_latent = self.init_latent, # changes how noising is performed in ksampler
|
||||
extra_conditioning_info = extra_conditioning_info,
|
||||
all_timesteps_count = steps
|
||||
)
|
||||
|
||||
return self.sample_to_image(samples)
|
||||
|
@ -73,7 +73,8 @@ class Inpaint(Img2Img):
|
||||
) # move to latent space
|
||||
|
||||
t_enc = int(strength * steps)
|
||||
uc, c = conditioning
|
||||
# todo: support cross-attention control
|
||||
uc, c, _ = conditioning
|
||||
|
||||
print(f">> target t_enc is {t_enc} steps")
|
||||
|
||||
|
@ -5,6 +5,8 @@ ldm.invoke.generator.txt2img inherits from ldm.invoke.generator
|
||||
import torch
|
||||
import numpy as np
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
|
||||
class Txt2Img(Generator):
|
||||
def __init__(self, model, precision):
|
||||
@ -19,7 +21,7 @@ class Txt2Img(Generator):
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
self.perlin = perlin
|
||||
uc, c = conditioning
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
@ -43,6 +45,7 @@ class Txt2Img(Generator):
|
||||
verbose = False,
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
extra_conditioning_info = extra_conditioning_info,
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback,
|
||||
threshold = threshold,
|
||||
|
@ -7,6 +7,7 @@ import numpy as np
|
||||
import math
|
||||
from ldm.invoke.generator.base import Generator
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
|
||||
class Txt2Img2Img(Generator):
|
||||
@ -22,7 +23,7 @@ class Txt2Img2Img(Generator):
|
||||
Return value depends on the seed at the time you call it
|
||||
kwargs are 'width' and 'height'
|
||||
"""
|
||||
uc, c = conditioning
|
||||
uc, c, extra_conditioning_info = conditioning
|
||||
|
||||
@torch.no_grad()
|
||||
def make_image(x_T):
|
||||
@ -60,7 +61,8 @@ class Txt2Img2Img(Generator):
|
||||
unconditional_guidance_scale = cfg_scale,
|
||||
unconditional_conditioning = uc,
|
||||
eta = ddim_eta,
|
||||
img_callback = step_callback
|
||||
img_callback = step_callback,
|
||||
extra_conditioning_info = extra_conditioning_info
|
||||
)
|
||||
|
||||
print(
|
||||
@ -94,6 +96,8 @@ class Txt2Img2Img(Generator):
|
||||
img_callback = step_callback,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
all_timesteps_count=steps
|
||||
)
|
||||
|
||||
if self.free_gpu_mem:
|
||||
|
680
ldm/invoke/prompt_parser.py
Normal file
680
ldm/invoke/prompt_parser.py
Normal file
@ -0,0 +1,680 @@
|
||||
import string
|
||||
from typing import Union, Optional
|
||||
import re
|
||||
import pyparsing as pp
|
||||
|
||||
class Prompt():
|
||||
"""
|
||||
Mid-level structure for storing the tree-like result of parsing a prompt. A Prompt may not represent the whole of
|
||||
the singular user-defined "prompt string" (although it can) - for example, if the user specifies a Blend, the objects
|
||||
that are to be blended together are stored individuall as Prompt objects.
|
||||
|
||||
Nesting makes this object not suitable for directly tokenizing; instead call flatten() on the containing Conjunction
|
||||
to produce a FlattenedPrompt.
|
||||
"""
|
||||
def __init__(self, parts: list):
|
||||
for c in parts:
|
||||
if type(c) is not Attention and not issubclass(type(c), BaseFragment) and type(c) is not pp.ParseResults:
|
||||
raise PromptParser.ParsingException(f"Prompt cannot contain {type(c).__name__} {c}, only {BaseFragment.__subclasses__()} are allowed")
|
||||
self.children = parts
|
||||
def __repr__(self):
|
||||
return f"Prompt:{self.children}"
|
||||
def __eq__(self, other):
|
||||
return type(other) is Prompt and other.children == self.children
|
||||
|
||||
class BaseFragment:
|
||||
pass
|
||||
|
||||
class FlattenedPrompt():
|
||||
"""
|
||||
A Prompt that has been passed through flatten(). Its children can be readily tokenized.
|
||||
"""
|
||||
def __init__(self, parts: list=[]):
|
||||
self.children = []
|
||||
for part in parts:
|
||||
self.append(part)
|
||||
|
||||
def append(self, fragment: Union[list, BaseFragment, tuple]):
|
||||
# verify type correctness
|
||||
if type(fragment) is list:
|
||||
for x in fragment:
|
||||
self.append(x)
|
||||
elif issubclass(type(fragment), BaseFragment):
|
||||
self.children.append(fragment)
|
||||
elif type(fragment) is tuple:
|
||||
# upgrade tuples to Fragments
|
||||
if type(fragment[0]) is not str or (type(fragment[1]) is not float and type(fragment[1]) is not int):
|
||||
raise PromptParser.ParsingException(
|
||||
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
|
||||
self.children.append(Fragment(fragment[0], fragment[1]))
|
||||
else:
|
||||
raise PromptParser.ParsingException(
|
||||
f"FlattenedPrompt cannot contain {fragment}, only Fragments or (str, float) tuples are allowed")
|
||||
|
||||
@property
|
||||
def is_empty(self):
|
||||
return len(self.children) == 0 or \
|
||||
(len(self.children) == 1 and len(self.children[0].text) == 0)
|
||||
|
||||
def __repr__(self):
|
||||
return f"FlattenedPrompt:{self.children}"
|
||||
def __eq__(self, other):
|
||||
return type(other) is FlattenedPrompt and other.children == self.children
|
||||
|
||||
|
||||
class Fragment(BaseFragment):
|
||||
"""
|
||||
A Fragment is a chunk of plain text and an optional weight. The text should be passed as-is to the CLIP tokenizer.
|
||||
"""
|
||||
def __init__(self, text: str, weight: float=1):
|
||||
assert(type(text) is str)
|
||||
if '\\"' in text or '\\(' in text or '\\)' in text:
|
||||
#print("Fragment converting escaped \( \) \\\" into ( ) \"")
|
||||
text = text.replace('\\(', '(').replace('\\)', ')').replace('\\"', '"')
|
||||
self.text = text
|
||||
self.weight = float(weight)
|
||||
|
||||
def __repr__(self):
|
||||
return "Fragment:'"+self.text+"'@"+str(self.weight)
|
||||
def __eq__(self, other):
|
||||
return type(other) is Fragment \
|
||||
and other.text == self.text \
|
||||
and other.weight == self.weight
|
||||
|
||||
class Attention():
|
||||
"""
|
||||
Nestable weight control for fragments. Each object in the children array may in turn be an Attention object;
|
||||
weights should be considered to accumulate as the tree is traversed to deeper levels of nesting.
|
||||
|
||||
Do not traverse directly; instead obtain a FlattenedPrompt by calling Flatten() on a top-level Conjunction object.
|
||||
"""
|
||||
def __init__(self, weight: float, children: list):
|
||||
self.weight = weight
|
||||
self.children = children
|
||||
#print(f"A: requested attention '{children}' to {weight}")
|
||||
|
||||
def __repr__(self):
|
||||
return f"Attention:'{self.children}' @ {self.weight}"
|
||||
def __eq__(self, other):
|
||||
return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment
|
||||
|
||||
class CrossAttentionControlledFragment(BaseFragment):
|
||||
pass
|
||||
|
||||
class CrossAttentionControlSubstitute(CrossAttentionControlledFragment):
|
||||
"""
|
||||
A Cross-Attention Controlled ('prompt2prompt') fragment, for use inside a Prompt, Attention, or FlattenedPrompt.
|
||||
Representing an "original" word sequence that supplies feature vectors for an initial diffusion operation, and an
|
||||
"edited" word sequence, to which the attention maps produced by the "original" word sequence are applied. Intuitively,
|
||||
the result should be an "edited" image that looks like the "original" image with concepts swapped.
|
||||
|
||||
eg "a cat sitting on a car" (original) -> "a smiling dog sitting on a car" (edited): the edited image should look
|
||||
almost exactly the same as the original, but with a smiling dog rendered in place of the cat. The
|
||||
CrossAttentionControlSubstitute object representing this swap may be confined to the tokens being swapped:
|
||||
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')])
|
||||
or it may represent a larger portion of the token sequence:
|
||||
CrossAttentionControlSubstitute(original=[Fragment('a cat sitting on a car')],
|
||||
edited=[Fragment('a smiling dog sitting on a car')])
|
||||
|
||||
In either case expect it to be embedded in a Prompt or FlattenedPrompt:
|
||||
FlattenedPrompt([
|
||||
Fragment('a'),
|
||||
CrossAttentionControlSubstitute(original=[Fragment('cat')], edited=[Fragment('dog')]),
|
||||
Fragment('sitting on a car')
|
||||
])
|
||||
"""
|
||||
def __init__(self, original: Union[Fragment, list], edited: Union[Fragment, list], options: dict=None):
|
||||
self.original = original
|
||||
self.edited = edited
|
||||
|
||||
default_options = {
|
||||
's_start': 0.0,
|
||||
's_end': 0.206, # ~= shape_freedom=0.5
|
||||
't_start': 0.0,
|
||||
't_end': 1.0
|
||||
}
|
||||
merged_options = default_options
|
||||
if options is not None:
|
||||
shape_freedom = options.pop('shape_freedom', None)
|
||||
if shape_freedom is not None:
|
||||
# high shape freedom = SD can do what it wants with the shape of the object
|
||||
# high shape freedom => s_end = 0
|
||||
# low shape freedom => s_end = 1
|
||||
# shape freedom is in a "linear" space, while noticeable changes to s_end are typically closer around 0,
|
||||
# and there is very little perceptible difference as s_end increases above 0.5
|
||||
# so for shape_freedom = 0.5 we probably want s_end to be 0.2
|
||||
# -> cube root and subtract from 1.0
|
||||
merged_options['s_end'] = 1.0 - shape_freedom ** (1. / 3.)
|
||||
print('converted shape_freedom argument to', merged_options)
|
||||
merged_options.update(options)
|
||||
|
||||
self.options = merged_options
|
||||
|
||||
def __repr__(self):
|
||||
return f"CrossAttentionControlSubstitute:({self.original}->{self.edited} ({self.options})"
|
||||
def __eq__(self, other):
|
||||
return type(other) is CrossAttentionControlSubstitute \
|
||||
and other.original == self.original \
|
||||
and other.edited == self.edited \
|
||||
and other.options == self.options
|
||||
|
||||
|
||||
class CrossAttentionControlAppend(CrossAttentionControlledFragment):
|
||||
def __init__(self, fragment: Fragment):
|
||||
self.fragment = fragment
|
||||
def __repr__(self):
|
||||
return "CrossAttentionControlAppend:",self.fragment
|
||||
def __eq__(self, other):
|
||||
return type(other) is CrossAttentionControlAppend \
|
||||
and other.fragment == self.fragment
|
||||
|
||||
|
||||
|
||||
class Conjunction():
|
||||
"""
|
||||
Storage for one or more Prompts or Blends, each of which is to be separately diffused and then the results merged
|
||||
by weighted sum in latent space.
|
||||
"""
|
||||
def __init__(self, prompts: list, weights: list = None):
|
||||
# force everything to be a Prompt
|
||||
#print("making conjunction with", parts)
|
||||
self.prompts = [x if (type(x) is Prompt
|
||||
or type(x) is Blend
|
||||
or type(x) is FlattenedPrompt)
|
||||
else Prompt(x) for x in prompts]
|
||||
self.weights = [1.0]*len(self.prompts) if weights is None else list(weights)
|
||||
if len(self.weights) != len(self.prompts):
|
||||
raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
|
||||
self.type = 'AND'
|
||||
|
||||
def __repr__(self):
|
||||
return f"Conjunction:{self.prompts} | weights {self.weights}"
|
||||
def __eq__(self, other):
|
||||
return type(other) is Conjunction \
|
||||
and other.prompts == self.prompts \
|
||||
and other.weights == self.weights
|
||||
|
||||
|
||||
class Blend():
|
||||
"""
|
||||
Stores a Blend of multiple Prompts. To apply, build feature vectors for each of the child Prompts and then perform a
|
||||
weighted blend of the feature vectors to produce a single feature vector that is effectively a lerp between the
|
||||
Prompts.
|
||||
"""
|
||||
def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
|
||||
#print("making Blend with prompts", prompts, "and weights", weights)
|
||||
if len(prompts) != len(weights):
|
||||
raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
|
||||
for c in prompts:
|
||||
if type(c) is not Prompt and type(c) is not FlattenedPrompt:
|
||||
raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
|
||||
# upcast all lists to Prompt objects
|
||||
self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
|
||||
else Prompt(x) for x in prompts]
|
||||
self.prompts = prompts
|
||||
self.weights = weights
|
||||
self.normalize_weights = normalize_weights
|
||||
|
||||
def __repr__(self):
|
||||
return f"Blend:{self.prompts} | weights {' ' if self.normalize_weights else '(non-normalized) '}{self.weights}"
|
||||
def __eq__(self, other):
|
||||
return other.__repr__() == self.__repr__()
|
||||
|
||||
|
||||
class PromptParser():
|
||||
|
||||
class ParsingException(Exception):
|
||||
pass
|
||||
|
||||
def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
|
||||
|
||||
self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base)
|
||||
|
||||
|
||||
def parse_conjunction(self, prompt: str) -> Conjunction:
|
||||
'''
|
||||
:param prompt: The prompt string to parse
|
||||
:return: a Conjunction representing the parsed results.
|
||||
'''
|
||||
#print(f"!!parsing '{prompt}'")
|
||||
|
||||
if len(prompt.strip()) == 0:
|
||||
return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0])
|
||||
|
||||
root = self.conjunction.parse_string(prompt)
|
||||
#print(f"'{prompt}' parsed to root", root)
|
||||
#fused = fuse_fragments(parts)
|
||||
#print("fused to", fused)
|
||||
|
||||
return self.flatten(root[0])
|
||||
|
||||
def parse_legacy_blend(self, text: str) -> Optional[Blend]:
|
||||
weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False)
|
||||
if len(weighted_subprompts) == 1:
|
||||
return None
|
||||
strings = [x[0] for x in weighted_subprompts]
|
||||
weights = [x[1] for x in weighted_subprompts]
|
||||
|
||||
parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
|
||||
flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
|
||||
|
||||
return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
|
||||
|
||||
|
||||
def flatten(self, root: Conjunction) -> Conjunction:
|
||||
"""
|
||||
Flattening a Conjunction traverses all of the nested tree-like structures in each of its Prompts or Blends,
|
||||
producing from each of these walks a linear sequence of Fragment or CrossAttentionControlSubstitute objects
|
||||
that can be readily tokenized without the need to walk a complex tree structure.
|
||||
|
||||
:param root: The Conjunction to flatten.
|
||||
:return: A Conjunction containing the result of flattening each of the prompts in the passed-in root.
|
||||
"""
|
||||
|
||||
#print("flattening", root)
|
||||
|
||||
def fuse_fragments(items):
|
||||
# print("fusing fragments in ", items)
|
||||
result = []
|
||||
for x in items:
|
||||
if type(x) is CrossAttentionControlSubstitute:
|
||||
original_fused = fuse_fragments(x.original)
|
||||
edited_fused = fuse_fragments(x.edited)
|
||||
result.append(CrossAttentionControlSubstitute(original_fused, edited_fused, options=x.options))
|
||||
else:
|
||||
last_weight = result[-1].weight \
|
||||
if (len(result) > 0 and not issubclass(type(result[-1]), CrossAttentionControlledFragment)) \
|
||||
else None
|
||||
this_text = x.text
|
||||
this_weight = x.weight
|
||||
if last_weight is not None and last_weight == this_weight:
|
||||
last_text = result[-1].text
|
||||
result[-1] = Fragment(last_text + ' ' + this_text, last_weight)
|
||||
else:
|
||||
result.append(x)
|
||||
return result
|
||||
|
||||
def flatten_internal(node, weight_scale, results, prefix):
|
||||
#print(prefix + "flattening", node, "...")
|
||||
if type(node) is pp.ParseResults:
|
||||
for x in node:
|
||||
results = flatten_internal(x, weight_scale, results, prefix+' pr ')
|
||||
#print(prefix, " ParseResults expanded, results is now", results)
|
||||
elif type(node) is Attention:
|
||||
# if node.weight < 1:
|
||||
# todo: inject a blend when flattening attention with weight <1"
|
||||
for index,c in enumerate(node.children):
|
||||
results = flatten_internal(c, weight_scale * node.weight, results, prefix + f" att{index} ")
|
||||
elif type(node) is Fragment:
|
||||
results += [Fragment(node.text, node.weight*weight_scale)]
|
||||
elif type(node) is CrossAttentionControlSubstitute:
|
||||
original = flatten_internal(node.original, weight_scale, [], prefix + ' CAo ')
|
||||
edited = flatten_internal(node.edited, weight_scale, [], prefix + ' CAe ')
|
||||
results += [CrossAttentionControlSubstitute(original, edited, options=node.options)]
|
||||
elif type(node) is Blend:
|
||||
flattened_subprompts = []
|
||||
#print(" flattening blend with prompts", node.prompts, "weights", node.weights)
|
||||
for prompt in node.prompts:
|
||||
# prompt is a list
|
||||
flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ')
|
||||
results += [Blend(prompts=flattened_subprompts, weights=node.weights, normalize_weights=node.normalize_weights)]
|
||||
elif type(node) is Prompt:
|
||||
#print(prefix + "about to flatten Prompt with children", node.children)
|
||||
flattened_prompt = []
|
||||
for child in node.children:
|
||||
flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ')
|
||||
results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))]
|
||||
#print(prefix + "after flattening Prompt, results is", results)
|
||||
else:
|
||||
raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
|
||||
#print(prefix + "-> after flattening", type(node).__name__, "results is", results)
|
||||
return results
|
||||
|
||||
|
||||
flattened_parts = []
|
||||
for part in root.prompts:
|
||||
flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
|
||||
|
||||
#print("flattened to", flattened_parts)
|
||||
|
||||
weights = root.weights
|
||||
return Conjunction(flattened_parts, weights)
|
||||
|
||||
|
||||
|
||||
def build_parser_syntax(attention_plus_base: float, attention_minus_base: float):
|
||||
|
||||
lparen = pp.Literal("(").suppress()
|
||||
rparen = pp.Literal(")").suppress()
|
||||
quotes = pp.Literal('"').suppress()
|
||||
comma = pp.Literal(",").suppress()
|
||||
|
||||
# accepts int or float notation, always maps to float
|
||||
number = pp.pyparsing_common.real | \
|
||||
pp.Combine(pp.Optional("-")+pp.Word(pp.nums)).set_parse_action(pp.token_map(float))
|
||||
greedy_word = pp.Word(pp.printables, exclude_chars=string.whitespace).set_name('greedy_word')
|
||||
|
||||
attention = pp.Forward()
|
||||
quoted_fragment = pp.Forward()
|
||||
parenthesized_fragment = pp.Forward()
|
||||
cross_attention_substitute = pp.Forward()
|
||||
prompt_part = pp.Forward()
|
||||
|
||||
def make_text_fragment(x):
|
||||
#print("### making fragment for", x)
|
||||
if type(x) is str:
|
||||
return Fragment(x)
|
||||
elif type(x) is pp.ParseResults or type(x) is list:
|
||||
#print(f'converting {type(x).__name__} to Fragment')
|
||||
return Fragment(' '.join([s for s in x]))
|
||||
else:
|
||||
raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
|
||||
|
||||
def build_escaped_word_parser(escaped_chars_to_ignore: str):
|
||||
terms = []
|
||||
for c in escaped_chars_to_ignore:
|
||||
terms.append(pp.Literal('\\'+c))
|
||||
terms.append(
|
||||
#pp.CharsNotIn(string.whitespace + escaped_chars_to_ignore, exact=1)
|
||||
pp.Word(pp.printables, exclude_chars=string.whitespace + escaped_chars_to_ignore)
|
||||
)
|
||||
return pp.Combine(pp.OneOrMore(
|
||||
pp.MatchFirst(terms)
|
||||
))
|
||||
|
||||
def build_escaped_word_parser_charbychar(escaped_chars_to_ignore: str):
|
||||
escapes = []
|
||||
for c in escaped_chars_to_ignore:
|
||||
escapes.append(pp.Literal('\\'+c))
|
||||
return pp.Combine(pp.OneOrMore(
|
||||
pp.MatchFirst(escapes + [pp.CharsNotIn(
|
||||
string.whitespace + escaped_chars_to_ignore,
|
||||
exact=1
|
||||
)])
|
||||
))
|
||||
|
||||
|
||||
|
||||
def parse_fragment_str(x, in_quotes: bool=False, in_parens: bool=False):
|
||||
#print(f"parsing fragment string \"{x}\"")
|
||||
fragment_string = x[0]
|
||||
if len(fragment_string.strip()) == 0:
|
||||
return Fragment('')
|
||||
|
||||
if in_quotes:
|
||||
# escape unescaped quotes
|
||||
fragment_string = fragment_string.replace('"', '\\"')
|
||||
|
||||
#fragment_parser = pp.Group(pp.OneOrMore(attention | cross_attention_substitute | (greedy_word.set_parse_action(make_text_fragment))))
|
||||
result = pp.Group(pp.MatchFirst([
|
||||
pp.OneOrMore(prompt_part | quoted_fragment),
|
||||
pp.Empty().set_parse_action(make_text_fragment) + pp.StringEnd()
|
||||
])).set_name('rr').set_debug(False).parse_string(fragment_string)
|
||||
#result = (pp.OneOrMore(attention | unquoted_word) + pp.StringEnd()).parse_string(x[0])
|
||||
#print("parsed to", result)
|
||||
return result
|
||||
|
||||
quoted_fragment << pp.QuotedString(quote_char='"', esc_char=None, esc_quote='\\"')
|
||||
quoted_fragment.set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_name('quoted_fragment')
|
||||
|
||||
escaped_quote = pp.Literal('\\"')#.set_parse_action(lambda x: '"')
|
||||
escaped_lparen = pp.Literal('\\(')#.set_parse_action(lambda x: '(')
|
||||
escaped_rparen = pp.Literal('\\)')#.set_parse_action(lambda x: ')')
|
||||
escaped_backslash = pp.Literal('\\\\')#.set_parse_action(lambda x: '"')
|
||||
|
||||
def not_ends_with_swap(x):
|
||||
#print("trying to match:", x)
|
||||
return not x[0].endswith('.swap')
|
||||
|
||||
unquoted_fragment = pp.Combine(pp.OneOrMore(
|
||||
escaped_rparen | escaped_lparen | escaped_quote | escaped_backslash |
|
||||
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()')))
|
||||
unquoted_fragment.set_parse_action(make_text_fragment).set_name('unquoted_fragment').set_debug(False)
|
||||
#print(unquoted_fragment.parse_string("cat.swap(dog)"))
|
||||
|
||||
parenthesized_fragment << pp.Or([
|
||||
(lparen + quoted_fragment.copy().set_parse_action(lambda x: parse_fragment_str(x, in_quotes=True)).set_debug(False) + rparen).set_name('-quoted_paren_internal').set_debug(False),
|
||||
(lparen + rparen).set_parse_action(lambda x: make_text_fragment('')).set_name('-()').set_debug(False),
|
||||
(lparen + pp.Combine(pp.OneOrMore(
|
||||
escaped_quote | escaped_lparen | escaped_rparen | escaped_backslash |
|
||||
pp.Word(pp.printables, exclude_chars=string.whitespace + '\\"()') |
|
||||
pp.Word(string.whitespace)
|
||||
)).set_name('--combined').set_parse_action(lambda x: parse_fragment_str(x, in_parens=True)).set_debug(False) + rparen)]).set_name('-unquoted_paren_internal').set_debug(False)
|
||||
parenthesized_fragment.set_name('parenthesized_fragment').set_debug(False)
|
||||
|
||||
debug_attention = False
|
||||
# attention control of the form (phrase)+ / (phrase)+ / (phrase)<weight>
|
||||
# phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
|
||||
attention_with_parens = pp.Forward()
|
||||
attention_without_parens = pp.Forward()
|
||||
|
||||
attention_with_parens_foot = (number | pp.Word('+') | pp.Word('-'))\
|
||||
.set_name("attention_foot")\
|
||||
.set_debug(False)
|
||||
attention_with_parens <<= pp.Group(
|
||||
lparen +
|
||||
pp.ZeroOrMore(quoted_fragment | attention_with_parens | parenthesized_fragment | cross_attention_substitute | attention_without_parens |
|
||||
(pp.Empty() + build_escaped_word_parser_charbychar('()')).set_name('undecorated_word').set_debug(debug_attention)#.set_parse_action(lambda t: t[0])
|
||||
)
|
||||
+ rparen + attention_with_parens_foot)
|
||||
attention_with_parens.set_name('attention_with_parens').set_debug(debug_attention)
|
||||
|
||||
attention_without_parens_foot = pp.Or(pp.Word('+') | pp.Word('-')).set_name('attention_without_parens_foots')
|
||||
attention_without_parens <<= pp.Group(
|
||||
(quoted_fragment.copy().set_name('attention_quoted_fragment_without_parens').set_debug(debug_attention) + attention_without_parens_foot) |
|
||||
pp.Combine(build_escaped_word_parser_charbychar('()+-')).set_name('attention_word_without_parens').set_debug(debug_attention)#.set_parse_action(lambda x: print('escapéd', x))
|
||||
+ attention_without_parens_foot)#.leave_whitespace()
|
||||
attention_without_parens.set_name('attention_without_parens').set_debug(debug_attention)
|
||||
|
||||
|
||||
attention << pp.MatchFirst([attention_with_parens,
|
||||
attention_without_parens
|
||||
])
|
||||
attention.set_name('attention')
|
||||
|
||||
def make_attention(x):
|
||||
#print("entered make_attention with", x)
|
||||
children = x[0][:-1]
|
||||
weight_raw = x[0][-1]
|
||||
weight = 1.0
|
||||
if type(weight_raw) is float or type(weight_raw) is int:
|
||||
weight = weight_raw
|
||||
elif type(weight_raw) is str:
|
||||
base = attention_plus_base if weight_raw[0] == '+' else attention_minus_base
|
||||
weight = pow(base, len(weight_raw))
|
||||
|
||||
#print("making Attention from", children, "with weight", weight)
|
||||
|
||||
return Attention(weight=weight, children=[(Fragment(x) if type(x) is str else x) for x in children])
|
||||
|
||||
attention_with_parens.set_parse_action(make_attention)
|
||||
attention_without_parens.set_parse_action(make_attention)
|
||||
|
||||
#print("parsing test:", attention_with_parens.parse_string("mountain (man)1.1"))
|
||||
|
||||
# cross-attention control
|
||||
empty_string = ((lparen + rparen) |
|
||||
pp.Literal('""').suppress() |
|
||||
(lparen + pp.Literal('""').suppress() + rparen)
|
||||
).set_parse_action(lambda x: Fragment(""))
|
||||
empty_string.set_name('empty_string')
|
||||
|
||||
# cross attention control
|
||||
debug_cross_attention_control = False
|
||||
original_fragment = pp.Or([empty_string.set_debug(debug_cross_attention_control),
|
||||
quoted_fragment.set_debug(debug_cross_attention_control),
|
||||
parenthesized_fragment.set_debug(debug_cross_attention_control),
|
||||
pp.Word(pp.printables, exclude_chars=string.whitespace + '.').set_parse_action(make_text_fragment) + pp.FollowedBy(".swap")
|
||||
])
|
||||
# support keyword=number arguments
|
||||
cross_attention_option_keyword = pp.Or([pp.Keyword("s_start"), pp.Keyword("s_end"), pp.Keyword("t_start"), pp.Keyword("t_end"), pp.Keyword("shape_freedom")])
|
||||
cross_attention_option = pp.Group(cross_attention_option_keyword + pp.Literal("=").suppress() + number)
|
||||
edited_fragment = pp.MatchFirst([
|
||||
lparen +
|
||||
(quoted_fragment |
|
||||
pp.Group(pp.OneOrMore(pp.Word(pp.printables, exclude_chars=string.whitespace + ',').set_parse_action(make_text_fragment)))
|
||||
) +
|
||||
pp.Dict(pp.OneOrMore(comma + cross_attention_option)) +
|
||||
rparen,
|
||||
parenthesized_fragment
|
||||
])
|
||||
cross_attention_substitute << original_fragment + pp.Literal(".swap").suppress() + edited_fragment
|
||||
|
||||
original_fragment.set_name('original_fragment').set_debug(debug_cross_attention_control)
|
||||
edited_fragment.set_name('edited_fragment').set_debug(debug_cross_attention_control)
|
||||
cross_attention_substitute.set_name('cross_attention_substitute').set_debug(debug_cross_attention_control)
|
||||
|
||||
def make_cross_attention_substitute(x):
|
||||
#print("making cacs for", x[0], "->", x[1], "with options", x.as_dict())
|
||||
#if len(x>2):
|
||||
cacs = CrossAttentionControlSubstitute(x[0], x[1], options=x.as_dict())
|
||||
#print("made", cacs)
|
||||
return cacs
|
||||
cross_attention_substitute.set_parse_action(make_cross_attention_substitute)
|
||||
|
||||
|
||||
# simple fragments of text
|
||||
# use Or to match the longest
|
||||
prompt_part << pp.MatchFirst([
|
||||
cross_attention_substitute,
|
||||
attention,
|
||||
unquoted_fragment,
|
||||
lparen + unquoted_fragment + rparen # matches case where user has +(term) and just deletes the +
|
||||
])
|
||||
prompt_part.set_debug(False)
|
||||
prompt_part.set_name("prompt_part")
|
||||
|
||||
empty = (
|
||||
(lparen + pp.ZeroOrMore(pp.Word(string.whitespace)) + rparen) |
|
||||
(quotes + pp.ZeroOrMore(pp.Word(string.whitespace)) + quotes)).set_debug(False).set_name('empty')
|
||||
|
||||
# root prompt definition
|
||||
prompt = (pp.OneOrMore(pp.Or([prompt_part, quoted_fragment, empty])) + pp.StringEnd()) \
|
||||
.set_parse_action(lambda x: Prompt(x))
|
||||
|
||||
#print("parsing test:", prompt.parse_string("spaced eyes--"))
|
||||
#print("parsing test:", prompt.parse_string("eyes--"))
|
||||
|
||||
# weighted blend of prompts
|
||||
# ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or
|
||||
# int weights.
|
||||
# can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c)
|
||||
|
||||
def make_prompt_from_quoted_string(x):
|
||||
#print(' got quoted prompt', x)
|
||||
|
||||
x_unquoted = x[0][1:-1]
|
||||
if len(x_unquoted.strip()) == 0:
|
||||
# print(' b : just an empty string')
|
||||
return Prompt([Fragment('')])
|
||||
# print(' b parsing ', c_unquoted)
|
||||
x_parsed = prompt.parse_string(x_unquoted)
|
||||
#print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed)
|
||||
return x_parsed[0]
|
||||
|
||||
quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
|
||||
quoted_prompt.set_name('quoted_prompt')
|
||||
|
||||
debug_blend=False
|
||||
blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms').set_debug(debug_blend)
|
||||
blend_weights = (pp.delimited_list(number) + pp.Optional(pp.Char(",").suppress() + "no_normalize")).set_name('blend_weights').set_debug(debug_blend)
|
||||
blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
|
||||
+ pp.Literal(".blend").suppress()
|
||||
+ lparen + pp.Group(blend_weights) + rparen).set_name('blend')
|
||||
blend.set_debug(debug_blend)
|
||||
|
||||
def make_blend(x):
|
||||
prompts = x[0][0]
|
||||
weights = x[0][1]
|
||||
normalize = True
|
||||
if weights[-1] == 'no_normalize':
|
||||
normalize = False
|
||||
weights = weights[:-1]
|
||||
return Blend(prompts=prompts, weights=weights, normalize_weights=normalize)
|
||||
|
||||
blend.set_parse_action(make_blend)
|
||||
|
||||
conjunction_terms = blend_terms.copy().set_name('conjunction_terms')
|
||||
conjunction_weights = blend_weights.copy().set_name('conjunction_weights')
|
||||
conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen
|
||||
+ pp.Literal(".and").suppress()
|
||||
+ lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction')
|
||||
def make_conjunction(x):
|
||||
parts_raw = x[0][0]
|
||||
weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw)
|
||||
parts = [part for part in parts_raw]
|
||||
return Conjunction(parts, weights)
|
||||
conjunction_with_parens_and_quotes.set_parse_action(make_conjunction)
|
||||
|
||||
implicit_conjunction = pp.OneOrMore(blend | prompt).set_name('implicit_conjunction')
|
||||
implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
|
||||
|
||||
conjunction = conjunction_with_parens_and_quotes | implicit_conjunction
|
||||
conjunction.set_debug(False)
|
||||
|
||||
# top-level is a conjunction of one or more blends or prompts
|
||||
return conjunction, prompt
|
||||
|
||||
|
||||
|
||||
def split_weighted_subprompts(text, skip_normalize=False)->list:
|
||||
"""
|
||||
Legacy blend parsing.
|
||||
|
||||
grabs all text up to the first occurrence of ':'
|
||||
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
||||
if ':' has no value defined, defaults to 1.0
|
||||
repeats until no text remaining
|
||||
"""
|
||||
prompt_parser = re.compile("""
|
||||
(?P<prompt> # capture group for 'prompt'
|
||||
(?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:'
|
||||
) # end 'prompt'
|
||||
(?: # non-capture group
|
||||
:+ # match one or more ':' characters
|
||||
(?P<weight> # capture group for 'weight'
|
||||
-?\d+(?:\.\d+)? # match positive or negative integer or decimal number
|
||||
)? # end weight capture group, make optional
|
||||
\s* # strip spaces after weight
|
||||
| # OR
|
||||
$ # else, if no ':' then match end of line
|
||||
) # end non-capture group
|
||||
""", re.VERBOSE)
|
||||
parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
|
||||
match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
|
||||
if skip_normalize:
|
||||
return parsed_prompts
|
||||
weight_sum = sum(map(lambda x: x[1], parsed_prompts))
|
||||
if weight_sum == 0:
|
||||
print(
|
||||
"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
|
||||
equal_weight = 1 / max(len(parsed_prompts), 1)
|
||||
return [(x[0], equal_weight) for x in parsed_prompts]
|
||||
return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
|
||||
|
||||
|
||||
# shows how the prompt is tokenized
|
||||
# usually tokens have '</w>' to indicate end-of-word,
|
||||
# but for readability it has been replaced with ' '
|
||||
def log_tokenization(text, model, log=False, weight=1):
|
||||
if not log:
|
||||
return
|
||||
tokens = model.cond_stage_model.tokenizer._tokenize(text)
|
||||
tokenized = ""
|
||||
discarded = ""
|
||||
usedTokens = 0
|
||||
totalTokens = len(tokens)
|
||||
for i in range(0, totalTokens):
|
||||
token = tokens[i].replace('</w>', 'x` ')
|
||||
# alternate color
|
||||
s = (usedTokens % 6) + 1
|
||||
if i < model.cond_stage_model.max_length:
|
||||
tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
|
||||
usedTokens += 1
|
||||
else: # over max token length
|
||||
discarded = discarded + f"\x1b[0;3{s};40m{token}"
|
||||
print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
|
||||
if discarded != "":
|
||||
print(
|
||||
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
|
||||
)
|
238
ldm/models/diffusion/cross_attention_control.py
Normal file
238
ldm/models/diffusion/cross_attention_control.py
Normal file
@ -0,0 +1,238 @@
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
|
||||
# adapted from bloc97's CrossAttentionControl colab
|
||||
# https://github.com/bloc97/CrossAttentionControl
|
||||
|
||||
class CrossAttentionControl:
|
||||
|
||||
class Arguments:
|
||||
def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict):
|
||||
"""
|
||||
:param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768]
|
||||
:param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required)
|
||||
:param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes.
|
||||
"""
|
||||
# todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector
|
||||
self.edited_conditioning = edited_conditioning
|
||||
self.edit_opcodes = edit_opcodes
|
||||
|
||||
if edited_conditioning is not None:
|
||||
assert len(edit_opcodes) == len(edit_options), \
|
||||
"there must be 1 edit_options dict for each edit_opcodes tuple"
|
||||
non_none_edit_options = [x for x in edit_options if x is not None]
|
||||
assert len(non_none_edit_options)>0, "missing edit_options"
|
||||
if len(non_none_edit_options)>1:
|
||||
print('warning: cross-attention control options are not working properly for >1 edit')
|
||||
self.edit_options = non_none_edit_options[0]
|
||||
|
||||
class Context:
|
||||
def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int):
|
||||
"""
|
||||
:param arguments: Arguments for the cross-attention control process
|
||||
:param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
|
||||
"""
|
||||
self.arguments = arguments
|
||||
self.step_count = step_count
|
||||
|
||||
@classmethod
|
||||
def remove_cross_attention_control(cls, model):
|
||||
cls.remove_attention_function(model)
|
||||
|
||||
@classmethod
|
||||
def setup_cross_attention_control(cls, model,
|
||||
cross_attention_control_args: Arguments
|
||||
):
|
||||
"""
|
||||
Inject attention parameters and functions into the passed in model to enable cross attention editing.
|
||||
|
||||
:param model: The unet model to inject into.
|
||||
:param cross_attention_control_args: Arugments passeed to the CrossAttentionControl implementations
|
||||
:return: None
|
||||
"""
|
||||
|
||||
# adapted from init_attention_edit
|
||||
device = cross_attention_control_args.edited_conditioning.device
|
||||
|
||||
# urgh. should this be hardcoded?
|
||||
max_length = 77
|
||||
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
|
||||
mask = torch.zeros(max_length)
|
||||
indices_target = torch.arange(max_length, dtype=torch.long)
|
||||
indices = torch.zeros(max_length, dtype=torch.long)
|
||||
for name, a0, a1, b0, b1 in cross_attention_control_args.edit_opcodes:
|
||||
if b0 < max_length:
|
||||
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
|
||||
# these tokens have not been edited
|
||||
indices[b0:b1] = indices_target[a0:a1]
|
||||
mask[b0:b1] = 1
|
||||
|
||||
for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF):
|
||||
m.last_attn_slice_mask = None
|
||||
m.last_attn_slice_indices = None
|
||||
|
||||
for m in cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS):
|
||||
m.last_attn_slice_mask = mask.to(device)
|
||||
m.last_attn_slice_indices = indices.to(device)
|
||||
|
||||
cls.inject_attention_function(model)
|
||||
|
||||
|
||||
class CrossAttentionType(Enum):
|
||||
SELF = 1
|
||||
TOKENS = 2
|
||||
|
||||
@classmethod
|
||||
def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', percent_through:float=None)\
|
||||
-> list['CrossAttentionControl.CrossAttentionType']:
|
||||
"""
|
||||
Should cross-attention control be applied on the given step?
|
||||
:param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
|
||||
:return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
|
||||
"""
|
||||
if percent_through is None:
|
||||
return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS]
|
||||
|
||||
opts = context.arguments.edit_options
|
||||
to_control = []
|
||||
if opts['s_start'] <= percent_through and percent_through < opts['s_end']:
|
||||
to_control.append(cls.CrossAttentionType.SELF)
|
||||
if opts['t_start'] <= percent_through and percent_through < opts['t_end']:
|
||||
to_control.append(cls.CrossAttentionType.TOKENS)
|
||||
return to_control
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_attention_modules(cls, model, which: CrossAttentionType):
|
||||
which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2"
|
||||
return [module for name, module in model.named_modules() if
|
||||
type(module).__name__ == "CrossAttention" and which_attn in name]
|
||||
|
||||
@classmethod
|
||||
def clear_requests(cls, model):
|
||||
self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF)
|
||||
tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS)
|
||||
for m in self_attention_modules+tokens_attention_modules:
|
||||
m.save_last_attn_slice = False
|
||||
m.use_last_attn_slice = False
|
||||
|
||||
@classmethod
|
||||
def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
|
||||
modules = cls.get_attention_modules(model, cross_attention_type)
|
||||
for m in modules:
|
||||
# clear out the saved slice in case the outermost dim changes
|
||||
m.last_attn_slice = None
|
||||
m.save_last_attn_slice = True
|
||||
|
||||
@classmethod
|
||||
def request_apply_saved_attention_maps(cls, model, cross_attention_type: CrossAttentionType):
|
||||
modules = cls.get_attention_modules(model, cross_attention_type)
|
||||
for m in modules:
|
||||
m.use_last_attn_slice = True
|
||||
|
||||
|
||||
|
||||
@classmethod
|
||||
def inject_attention_function(cls, unet):
|
||||
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
|
||||
|
||||
def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
|
||||
|
||||
#print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim)
|
||||
|
||||
attn_slice = suggested_attention_slice
|
||||
if dim is not None:
|
||||
start = offset
|
||||
end = start+slice_size
|
||||
#print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
|
||||
#else:
|
||||
# print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
|
||||
|
||||
|
||||
if self.use_last_attn_slice:
|
||||
this_attn_slice = attn_slice
|
||||
if self.last_attn_slice_mask is not None:
|
||||
# indices and mask operate on dim=2, no need to slice
|
||||
base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
|
||||
base_attn_slice_mask = self.last_attn_slice_mask
|
||||
if dim is None:
|
||||
base_attn_slice = base_attn_slice_full
|
||||
#print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
|
||||
elif dim == 0:
|
||||
base_attn_slice = base_attn_slice_full[start:end]
|
||||
#print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
|
||||
elif dim == 1:
|
||||
base_attn_slice = base_attn_slice_full[:, start:end]
|
||||
#print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
|
||||
|
||||
attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \
|
||||
base_attn_slice * base_attn_slice_mask
|
||||
else:
|
||||
if dim is None:
|
||||
attn_slice = self.last_attn_slice
|
||||
#print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
|
||||
elif dim == 0:
|
||||
attn_slice = self.last_attn_slice[start:end]
|
||||
#print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
|
||||
elif dim == 1:
|
||||
attn_slice = self.last_attn_slice[:, start:end]
|
||||
#print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
|
||||
|
||||
if self.save_last_attn_slice:
|
||||
if dim is None:
|
||||
self.last_attn_slice = attn_slice
|
||||
elif dim == 0:
|
||||
# dynamically grow last_attn_slice if needed
|
||||
if self.last_attn_slice is None:
|
||||
self.last_attn_slice = attn_slice
|
||||
#print("no last_attn_slice: shape now", self.last_attn_slice.shape)
|
||||
elif self.last_attn_slice.shape[0] == start:
|
||||
self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0)
|
||||
assert(self.last_attn_slice.shape[0] == end)
|
||||
#print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
|
||||
else:
|
||||
# no need to grow
|
||||
self.last_attn_slice[start:end] = attn_slice
|
||||
#print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
|
||||
|
||||
elif dim == 1:
|
||||
# dynamically grow last_attn_slice if needed
|
||||
if self.last_attn_slice is None:
|
||||
self.last_attn_slice = attn_slice
|
||||
elif self.last_attn_slice.shape[1] == start:
|
||||
self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1)
|
||||
assert(self.last_attn_slice.shape[1] == end)
|
||||
else:
|
||||
# no need to grow
|
||||
self.last_attn_slice[:, start:end] = attn_slice
|
||||
|
||||
if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
|
||||
if dim is None:
|
||||
weights = self.last_attn_slice_weights
|
||||
elif dim == 0:
|
||||
weights = self.last_attn_slice_weights[start:end]
|
||||
elif dim == 1:
|
||||
weights = self.last_attn_slice_weights[:, start:end]
|
||||
attn_slice = attn_slice * weights
|
||||
|
||||
return attn_slice
|
||||
|
||||
for name, module in unet.named_modules():
|
||||
module_name = type(module).__name__
|
||||
if module_name == "CrossAttention":
|
||||
module.last_attn_slice = None
|
||||
module.last_attn_slice_indices = None
|
||||
module.last_attn_slice_mask = None
|
||||
module.use_last_attn_weights = False
|
||||
module.use_last_attn_slice = False
|
||||
module.save_last_attn_slice = False
|
||||
module.set_attention_slice_wrangler(attention_slice_wrangler)
|
||||
|
||||
@classmethod
|
||||
def remove_attention_function(cls, unet):
|
||||
for name, module in unet.named_modules():
|
||||
module_name = type(module).__name__
|
||||
if module_name == "CrossAttention":
|
||||
module.set_attention_slice_wrangler(None)
|
||||
|
@ -1,10 +1,7 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from ldm.models.diffusion.sampler import Sampler
|
||||
from ldm.modules.diffusionmodules.util import noise_like
|
||||
|
||||
@ -12,6 +9,21 @@ class DDIMSampler(Sampler):
|
||||
def __init__(self, model, schedule='linear', device=None, **kwargs):
|
||||
super().__init__(model,schedule,model.num_timesteps,device)
|
||||
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
|
||||
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
|
||||
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
super().prepare_to_sample(t_enc, **kwargs)
|
||||
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
|
||||
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
|
||||
else:
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
|
||||
# This is the central routine
|
||||
@torch.no_grad()
|
||||
def p_sample(
|
||||
@ -29,6 +41,7 @@ class DDIMSampler(Sampler):
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
step_count:int=1000, # total number of steps
|
||||
**kwargs,
|
||||
):
|
||||
b, *_, device = *x.shape, x.device
|
||||
@ -37,15 +50,14 @@ class DDIMSampler(Sampler):
|
||||
unconditional_conditioning is None
|
||||
or unconditional_guidance_scale == 1.0
|
||||
):
|
||||
# damian0815 would like to know when/if this code path is used
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (
|
||||
e_t - e_t_uncond
|
||||
)
|
||||
step_index = step_count-(index+1)
|
||||
e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
|
||||
unconditional_conditioning, c,
|
||||
unconditional_guidance_scale,
|
||||
step_index=step_index)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == 'eps'
|
||||
|
@ -820,21 +820,21 @@ class LatentDiffusion(DDPM):
|
||||
)
|
||||
return self.scale_factor * z
|
||||
|
||||
def get_learned_conditioning(self, c):
|
||||
def get_learned_conditioning(self, c, **kwargs):
|
||||
if self.cond_stage_forward is None:
|
||||
if hasattr(self.cond_stage_model, 'encode') and callable(
|
||||
self.cond_stage_model.encode
|
||||
):
|
||||
c = self.cond_stage_model.encode(
|
||||
c, embedding_manager=self.embedding_manager
|
||||
c, embedding_manager=self.embedding_manager, **kwargs
|
||||
)
|
||||
if isinstance(c, DiagonalGaussianDistribution):
|
||||
c = c.mode()
|
||||
else:
|
||||
c = self.cond_stage_model(c)
|
||||
c = self.cond_stage_model(c, **kwargs)
|
||||
else:
|
||||
assert hasattr(self.cond_stage_model, self.cond_stage_forward)
|
||||
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
|
||||
c = getattr(self.cond_stage_model, self.cond_stage_forward)(c, **kwargs)
|
||||
return c
|
||||
|
||||
def meshgrid(self, h, w):
|
||||
|
@ -1,16 +1,12 @@
|
||||
"""wrapper around part of Katherine Crowson's k-diffusion library, making it call compatible with other Samplers"""
|
||||
|
||||
import k_diffusion as K
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.models.diffusion.sampler import Sampler
|
||||
from ldm.util import rand_perlin_2d
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
make_ddim_sampling_parameters,
|
||||
make_ddim_timesteps,
|
||||
noise_like,
|
||||
extract_into_tensor,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from .sampler import Sampler
|
||||
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
|
||||
|
||||
def cfg_apply_threshold(result, threshold = 0.0, scale = 0.7):
|
||||
if threshold <= 0.0:
|
||||
@ -33,12 +29,24 @@ class CFGDenoiser(nn.Module):
|
||||
self.threshold = threshold
|
||||
self.warmup_max = warmup
|
||||
self.warmup = max(warmup / 10, 1)
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(model,
|
||||
model_forward_callback=lambda x, sigma, cond: self.inner_model(x, sigma, cond=cond))
|
||||
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = t_enc)
|
||||
else:
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
sigma_in = torch.cat([sigma] * 2)
|
||||
cond_in = torch.cat([uncond, cond])
|
||||
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
|
||||
|
||||
next_x = self.invokeai_diffuser.do_diffusion_step(x, sigma, uncond, cond, cond_scale)
|
||||
|
||||
# apply threshold
|
||||
if self.warmup < self.warmup_max:
|
||||
thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
|
||||
self.warmup += 1
|
||||
@ -46,7 +54,8 @@ class CFGDenoiser(nn.Module):
|
||||
thresh = self.threshold
|
||||
if thresh > self.threshold:
|
||||
thresh = self.threshold
|
||||
return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh)
|
||||
return cfg_apply_threshold(next_x, thresh)
|
||||
|
||||
|
||||
|
||||
class KSampler(Sampler):
|
||||
@ -61,16 +70,6 @@ class KSampler(Sampler):
|
||||
self.ds = None
|
||||
self.s_in = None
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
sigma_in = torch.cat([sigma] * 2)
|
||||
cond_in = torch.cat([uncond, cond])
|
||||
uncond, cond = self.inner_model(
|
||||
x_in, sigma_in, cond=cond_in
|
||||
).chunk(2)
|
||||
return uncond + (cond - uncond) * cond_scale
|
||||
|
||||
|
||||
def make_schedule(
|
||||
self,
|
||||
ddim_num_steps,
|
||||
@ -118,6 +117,7 @@ class KSampler(Sampler):
|
||||
use_original_steps=False,
|
||||
init_latent = None,
|
||||
mask = None,
|
||||
**kwargs
|
||||
):
|
||||
samples,_ = self.sample(
|
||||
batch_size = 1,
|
||||
@ -129,7 +129,8 @@ class KSampler(Sampler):
|
||||
unconditional_conditioning = unconditional_conditioning,
|
||||
img_callback = img_callback,
|
||||
x0 = init_latent,
|
||||
mask = mask
|
||||
mask = mask,
|
||||
**kwargs
|
||||
)
|
||||
return samples
|
||||
|
||||
@ -163,6 +164,7 @@ class KSampler(Sampler):
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
extra_conditioning_info=None,
|
||||
threshold = 0,
|
||||
perlin = 0,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
@ -181,7 +183,6 @@ class KSampler(Sampler):
|
||||
)
|
||||
|
||||
# sigmas are set up in make_schedule - we take the last steps items
|
||||
total_steps = len(self.sigmas)
|
||||
sigmas = self.sigmas[-S-1:]
|
||||
|
||||
# x_T is variation noise. When an init image is provided (in x0) we need to add
|
||||
@ -195,19 +196,21 @@ class KSampler(Sampler):
|
||||
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
|
||||
|
||||
model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
|
||||
model_wrap_cfg.prepare_to_sample(S, extra_conditioning_info=extra_conditioning_info)
|
||||
extra_args = {
|
||||
'cond': conditioning,
|
||||
'uncond': unconditional_conditioning,
|
||||
'cond_scale': unconditional_guidance_scale,
|
||||
}
|
||||
print(f'>> Sampling with k_{self.schedule} starting at step {len(self.sigmas)-S-1} of {len(self.sigmas)-1} ({S} new sampling steps)')
|
||||
return (
|
||||
sampling_result = (
|
||||
K.sampling.__dict__[f'sample_{self.schedule}'](
|
||||
model_wrap_cfg, x, sigmas, extra_args=extra_args,
|
||||
callback=route_callback
|
||||
),
|
||||
None,
|
||||
)
|
||||
return sampling_result
|
||||
|
||||
# this code will support inpainting if and when ksampler API modified or
|
||||
# a workaround is found.
|
||||
@ -220,6 +223,7 @@ class KSampler(Sampler):
|
||||
index,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
extra_conditioning_info=None,
|
||||
**kwargs,
|
||||
):
|
||||
if self.model_wrap is None:
|
||||
@ -245,6 +249,7 @@ class KSampler(Sampler):
|
||||
# so the actual formula for indexing into sigmas:
|
||||
# sigma_index = (steps-index)
|
||||
s_index = t_enc - index - 1
|
||||
self.model_wrap.prepare_to_sample(s_index, extra_conditioning_info=extra_conditioning_info)
|
||||
img = K.sampling.__dict__[f'_{self.schedule}'](
|
||||
self.model_wrap,
|
||||
img,
|
||||
@ -269,7 +274,7 @@ class KSampler(Sampler):
|
||||
else:
|
||||
return x
|
||||
|
||||
def prepare_to_sample(self,t_enc):
|
||||
def prepare_to_sample(self,t_enc,**kwargs):
|
||||
self.t_enc = t_enc
|
||||
self.model_wrap = None
|
||||
self.ds = None
|
||||
|
@ -5,6 +5,7 @@ import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from ldm.models.diffusion.sampler import Sampler
|
||||
from ldm.modules.diffusionmodules.util import noise_like
|
||||
|
||||
@ -13,6 +14,21 @@ class PLMSSampler(Sampler):
|
||||
def __init__(self, model, schedule='linear', device=None, **kwargs):
|
||||
super().__init__(model,schedule,model.num_timesteps, device)
|
||||
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(self.model,
|
||||
model_forward_callback = lambda x, sigma, cond: self.model.apply_model(x, sigma, cond))
|
||||
|
||||
def prepare_to_sample(self, t_enc, **kwargs):
|
||||
super().prepare_to_sample(t_enc, **kwargs)
|
||||
|
||||
extra_conditioning_info = kwargs.get('extra_conditioning_info', None)
|
||||
all_timesteps_count = kwargs.get('all_timesteps_count', t_enc)
|
||||
|
||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||
self.invokeai_diffuser.setup_cross_attention_control(extra_conditioning_info, step_count = all_timesteps_count)
|
||||
else:
|
||||
self.invokeai_diffuser.remove_cross_attention_control()
|
||||
|
||||
|
||||
# this is the essential routine
|
||||
@torch.no_grad()
|
||||
def p_sample(
|
||||
@ -32,6 +48,7 @@ class PLMSSampler(Sampler):
|
||||
unconditional_conditioning=None,
|
||||
old_eps=[],
|
||||
t_next=None,
|
||||
step_count:int=1000, # total number of steps
|
||||
**kwargs,
|
||||
):
|
||||
b, *_, device = *x.shape, x.device
|
||||
@ -41,17 +58,15 @@ class PLMSSampler(Sampler):
|
||||
unconditional_conditioning is None
|
||||
or unconditional_guidance_scale == 1.0
|
||||
):
|
||||
# damian0815 would like to know when/if this code path is used
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(
|
||||
x_in, t_in, c_in
|
||||
).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (
|
||||
e_t - e_t_uncond
|
||||
)
|
||||
# step_index counts in the opposite direction to index
|
||||
step_index = step_count-(index+1)
|
||||
e_t = self.invokeai_diffuser.do_diffusion_step(x, t,
|
||||
unconditional_conditioning, c,
|
||||
unconditional_guidance_scale,
|
||||
step_index=step_index)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == 'eps'
|
||||
|
@ -4,6 +4,8 @@ ldm.models.diffusion.sampler
|
||||
Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc
|
||||
|
||||
'''
|
||||
from math import ceil
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
@ -190,6 +192,7 @@ class Sampler(object):
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
steps=S,
|
||||
**kwargs
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@ -214,6 +217,7 @@ class Sampler(object):
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
steps=None,
|
||||
**kwargs
|
||||
):
|
||||
b = shape[0]
|
||||
time_range = (
|
||||
@ -231,7 +235,7 @@ class Sampler(object):
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
old_eps = []
|
||||
self.prepare_to_sample(t_enc=total_steps)
|
||||
self.prepare_to_sample(t_enc=total_steps,all_timesteps_count=steps,**kwargs)
|
||||
img = self.get_initial_image(x_T,shape,total_steps)
|
||||
|
||||
# probably don't need this at all
|
||||
@ -274,6 +278,7 @@ class Sampler(object):
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
old_eps=old_eps,
|
||||
t_next=ts_next,
|
||||
step_count=steps
|
||||
)
|
||||
img, pred_x0, e_t = outs
|
||||
|
||||
@ -305,6 +310,8 @@ class Sampler(object):
|
||||
use_original_steps=False,
|
||||
init_latent = None,
|
||||
mask = None,
|
||||
all_timesteps_count = None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
timesteps = (
|
||||
@ -321,7 +328,7 @@ class Sampler(object):
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
x0 = init_latent
|
||||
self.prepare_to_sample(t_enc=total_steps)
|
||||
self.prepare_to_sample(t_enc=total_steps, all_timesteps_count=all_timesteps_count, **kwargs)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
@ -353,6 +360,7 @@ class Sampler(object):
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
t_next = ts_next,
|
||||
step_count=len(self.ddim_timesteps)
|
||||
)
|
||||
|
||||
x_dec, pred_x0, e_t = outs
|
||||
@ -411,3 +419,4 @@ class Sampler(object):
|
||||
return self.model.inner_model.q_sample(x0,ts)
|
||||
'''
|
||||
return self.model.q_sample(x0,ts)
|
||||
|
||||
|
176
ldm/models/diffusion/shared_invokeai_diffusion.py
Normal file
176
ldm/models/diffusion/shared_invokeai_diffusion.py
Normal file
@ -0,0 +1,176 @@
|
||||
from math import ceil
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ldm.models.diffusion.cross_attention_control import CrossAttentionControl
|
||||
|
||||
|
||||
class InvokeAIDiffuserComponent:
|
||||
'''
|
||||
The aim of this component is to provide a single place for code that can be applied identically to
|
||||
all InvokeAI diffusion procedures.
|
||||
|
||||
At the moment it includes the following features:
|
||||
* Cross Attention Control ("prompt2prompt")
|
||||
'''
|
||||
|
||||
|
||||
class ExtraConditioningInfo:
|
||||
def __init__(self, cross_attention_control_args: Optional[CrossAttentionControl.Arguments]):
|
||||
self.cross_attention_control_args = cross_attention_control_args
|
||||
|
||||
@property
|
||||
def wants_cross_attention_control(self):
|
||||
return self.cross_attention_control_args is not None
|
||||
|
||||
def __init__(self, model, model_forward_callback:
|
||||
Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
|
||||
):
|
||||
"""
|
||||
:param model: the unet model to pass through to cross attention control
|
||||
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
|
||||
"""
|
||||
self.model = model
|
||||
self.model_forward_callback = model_forward_callback
|
||||
|
||||
|
||||
def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int):
|
||||
self.conditioning = conditioning
|
||||
self.cross_attention_control_context = CrossAttentionControl.Context(
|
||||
arguments=self.conditioning.cross_attention_control_args,
|
||||
step_count=step_count
|
||||
)
|
||||
CrossAttentionControl.setup_cross_attention_control(self.model,
|
||||
cross_attention_control_args=self.conditioning.cross_attention_control_args
|
||||
)
|
||||
#todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct
|
||||
#todo: apply edit_options using step_count
|
||||
|
||||
|
||||
def remove_cross_attention_control(self):
|
||||
self.conditioning = None
|
||||
self.cross_attention_control_context = None
|
||||
CrossAttentionControl.remove_cross_attention_control(self.model)
|
||||
|
||||
def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor,
|
||||
unconditioning: torch.Tensor, conditioning: torch.Tensor,
|
||||
unconditional_guidance_scale: float,
|
||||
step_index: int=None
|
||||
):
|
||||
"""
|
||||
:param x: Current latents
|
||||
:param sigma: aka t, passed to the internal model to control how much denoising will occur
|
||||
:param unconditioning: [B x 77 x 768] embeddings for unconditioned output
|
||||
:param conditioning: [B x 77 x 768] embeddings for conditioned output
|
||||
:param unconditional_guidance_scale: aka CFG scale, controls how much effect the conditioning tensor has
|
||||
:param step_index: Counts upwards from 0 to (step_count-1) (as passed to setup_cross_attention_control, if using). May be called multiple times for a single step, therefore do not assume that its value will monotically increase.
|
||||
:return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning.
|
||||
"""
|
||||
|
||||
CrossAttentionControl.clear_requests(self.model)
|
||||
cross_attention_control_types_to_do = []
|
||||
|
||||
if self.cross_attention_control_context is not None:
|
||||
if step_index is not None:
|
||||
# percent_through will never reach 1.0 (but this is intended)
|
||||
percent_through = float(step_index) / float(self.cross_attention_control_context.step_count)
|
||||
else:
|
||||
# find the current sigma in the sigma sequence
|
||||
# todo: this doesn't work with k_dpm_2 because the sigma used jumps around in the sequence
|
||||
sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1]
|
||||
# flip because sigmas[0] is for the fully denoised image
|
||||
# percent_through must be <1
|
||||
percent_through = 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0])
|
||||
#print('estimated percent_through', percent_through, 'from sigma', sigma.item())
|
||||
cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through)
|
||||
|
||||
if len(cross_attention_control_types_to_do)==0:
|
||||
#print('not doing cross attention control')
|
||||
# faster batched path
|
||||
x_twice = torch.cat([x]*2)
|
||||
sigma_twice = torch.cat([sigma]*2)
|
||||
both_conditionings = torch.cat([unconditioning, conditioning])
|
||||
unconditioned_next_x, conditioned_next_x = self.model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
|
||||
else:
|
||||
#print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
||||
# slower non-batched path (20% slower on mac MPS)
|
||||
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
|
||||
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
|
||||
# This messes app their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
|
||||
# (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
|
||||
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
|
||||
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
|
||||
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
|
||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning)
|
||||
|
||||
# process x using the original prompt, saving the attention maps
|
||||
for type in cross_attention_control_types_to_do:
|
||||
CrossAttentionControl.request_save_attention_maps(self.model, type)
|
||||
_ = self.model_forward_callback(x, sigma, conditioning)
|
||||
CrossAttentionControl.clear_requests(self.model)
|
||||
|
||||
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
||||
for type in cross_attention_control_types_to_do:
|
||||
CrossAttentionControl.request_apply_saved_attention_maps(self.model, type)
|
||||
edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning
|
||||
conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning)
|
||||
CrossAttentionControl.clear_requests(self.model)
|
||||
|
||||
|
||||
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * unconditional_guidance_scale
|
||||
combined_next_x = unconditioned_next_x + scaled_delta
|
||||
|
||||
return combined_next_x
|
||||
|
||||
# todo: make this work
|
||||
@classmethod
|
||||
def apply_conjunction(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2) # aka sigmas
|
||||
|
||||
deltas = None
|
||||
uncond_latents = None
|
||||
weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
|
||||
|
||||
# below is fugly omg
|
||||
num_actual_conditionings = len(c_or_weighted_c_list)
|
||||
conditionings = [uc] + [c for c,weight in weighted_cond_list]
|
||||
weights = [1] + [weight for c,weight in weighted_cond_list]
|
||||
chunk_count = ceil(len(conditionings)/2)
|
||||
deltas = None
|
||||
for chunk_index in range(chunk_count):
|
||||
offset = chunk_index*2
|
||||
chunk_size = min(2, len(conditionings)-offset)
|
||||
|
||||
if chunk_size == 1:
|
||||
c_in = conditionings[offset]
|
||||
latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
|
||||
latents_b = None
|
||||
else:
|
||||
c_in = torch.cat(conditionings[offset:offset+2])
|
||||
latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
|
||||
|
||||
# first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining
|
||||
if chunk_index == 0:
|
||||
uncond_latents = latents_a
|
||||
deltas = latents_b - uncond_latents
|
||||
else:
|
||||
deltas = torch.cat((deltas, latents_a - uncond_latents))
|
||||
if latents_b is not None:
|
||||
deltas = torch.cat((deltas, latents_b - uncond_latents))
|
||||
|
||||
# merge the weighted deltas together into a single merged delta
|
||||
per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
|
||||
normalize = False
|
||||
if normalize:
|
||||
per_delta_weights /= torch.sum(per_delta_weights)
|
||||
reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
|
||||
deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
|
||||
|
||||
# old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
|
||||
# assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
|
||||
|
||||
return uncond_latents + deltas_merged * global_guidance_scale
|
||||
|
@ -1,5 +1,7 @@
|
||||
from inspect import isfunction
|
||||
import math
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
@ -150,6 +152,7 @@ class SpatialSelfAttention(nn.Module):
|
||||
return x+h_
|
||||
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
@ -170,46 +173,73 @@ class CrossAttention(nn.Module):
|
||||
|
||||
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
|
||||
|
||||
def einsum_op_compvis(self, q, k, v):
|
||||
s = einsum('b i d, b j d -> b i j', q, k)
|
||||
s = s.softmax(dim=-1, dtype=s.dtype)
|
||||
return einsum('b i j, b j d -> b i d', s, v)
|
||||
self.attention_slice_wrangler = None
|
||||
|
||||
def einsum_op_slice_0(self, q, k, v, slice_size):
|
||||
def set_attention_slice_wrangler(self, wrangler:Callable[[nn.Module, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]):
|
||||
'''
|
||||
Set custom attention calculator to be called when attention is calculated
|
||||
:param wrangler: Callback, with args (self, attention_scores, suggested_attention_slice, dim, offset, slice_size),
|
||||
which returns either the suggested_attention_slice or an adjusted equivalent.
|
||||
self is the current CrossAttention module for which the callback is being invoked.
|
||||
attention_scores are the scores for attention
|
||||
suggested_attention_slice is a softmax(dim=-1) over attention_scores
|
||||
dim is -1 if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
|
||||
If dim is >= 0, offset and slice_size specify the slice start and length.
|
||||
|
||||
Pass None to use the default attention calculation.
|
||||
:return:
|
||||
'''
|
||||
self.attention_slice_wrangler = wrangler
|
||||
|
||||
def einsum_lowest_level(self, q, k, v, dim, offset, slice_size):
|
||||
# calculate attention scores
|
||||
attention_scores = einsum('b i d, b j d -> b i j', q, k)
|
||||
# calculate attenion slice by taking the best scores for each latent pixel
|
||||
default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
|
||||
if self.attention_slice_wrangler is not None:
|
||||
attention_slice = self.attention_slice_wrangler(self, attention_scores, default_attention_slice, dim, offset, slice_size)
|
||||
else:
|
||||
attention_slice = default_attention_slice
|
||||
|
||||
return einsum('b i j, b j d -> b i d', attention_slice, v)
|
||||
|
||||
def einsum_op_slice_dim0(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = i + slice_size
|
||||
r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end])
|
||||
r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_slice_1(self, q, k, v, slice_size):
|
||||
def einsum_op_slice_dim1(self, q, k, v, slice_size):
|
||||
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v)
|
||||
r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
|
||||
return r
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
return self.einsum_op_compvis(q, k, v)
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
return self.einsum_op_slice_1(q, k, v, slice_size)
|
||||
return self.einsum_op_slice_dim1(q, k, v, slice_size)
|
||||
|
||||
def einsum_op_mps_v2(self, q, k, v):
|
||||
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
|
||||
return self.einsum_op_compvis(q, k, v)
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
else:
|
||||
return self.einsum_op_slice_0(q, k, v, 1)
|
||||
return self.einsum_op_slice_dim0(q, k, v, 1)
|
||||
|
||||
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
|
||||
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
|
||||
if size_mb <= max_tensor_mb:
|
||||
return self.einsum_op_compvis(q, k, v)
|
||||
return self.einsum_lowest_level(q, k, v, None, None, None)
|
||||
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
|
||||
if div <= q.shape[0]:
|
||||
return self.einsum_op_slice_0(q, k, v, q.shape[0] // div)
|
||||
return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
|
||||
print("warning: untested call to einsum_op_slice_dim0")
|
||||
return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
|
||||
print("warning: untested call to einsum_op_slice_dim1")
|
||||
return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
|
||||
|
||||
def einsum_op_cuda(self, q, k, v):
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
@ -221,7 +251,7 @@ class CrossAttention(nn.Module):
|
||||
# Divide factor of safety as there's copying and fragmentation
|
||||
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
|
||||
|
||||
def einsum_op(self, q, k, v):
|
||||
def get_attention_mem_efficient(self, q, k, v):
|
||||
if q.device.type == 'cuda':
|
||||
return self.einsum_op_cuda(q, k, v)
|
||||
|
||||
@ -244,8 +274,13 @@ class CrossAttention(nn.Module):
|
||||
del context, x
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
r = self.einsum_op(q, k, v)
|
||||
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
|
||||
|
||||
r = self.get_attention_mem_efficient(q, k, v)
|
||||
|
||||
hidden_states = rearrange(r, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(hidden_states)
|
||||
|
||||
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
|
@ -1,3 +1,5 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partial
|
||||
@ -437,6 +439,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text, **kwargs):
|
||||
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
@ -454,6 +457,222 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
def encode(self, text, **kwargs):
|
||||
return self(text, **kwargs)
|
||||
|
||||
class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
|
||||
|
||||
fragment_weights_key = "fragment_weights"
|
||||
return_tokens_key = "return_tokens"
|
||||
|
||||
def forward(self, text: list, **kwargs):
|
||||
'''
|
||||
|
||||
:param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
|
||||
weights shall be applied.
|
||||
:param kwargs: If the keyword arg "fragment_weights" is passed, it shall contain a batch of lists of weights
|
||||
for the prompt fragments. In this case text must contain batches of lists of prompt fragments.
|
||||
:return: A tensor of shape (B, 77, 768) containing weighted embeddings
|
||||
'''
|
||||
if self.fragment_weights_key not in kwargs:
|
||||
# fallback to base class implementation
|
||||
return super().forward(text, **kwargs)
|
||||
|
||||
fragment_weights = kwargs[self.fragment_weights_key]
|
||||
# self.transformer doesn't like receiving "fragment_weights" as an argument
|
||||
kwargs.pop(self.fragment_weights_key)
|
||||
|
||||
should_return_tokens = False
|
||||
if self.return_tokens_key in kwargs:
|
||||
should_return_tokens = kwargs.get(self.return_tokens_key, False)
|
||||
# self.transformer doesn't like having extra kwargs
|
||||
kwargs.pop(self.return_tokens_key)
|
||||
|
||||
batch_z = None
|
||||
batch_tokens = None
|
||||
for fragments, weights in zip(text, fragment_weights):
|
||||
|
||||
# First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
|
||||
# applying a multiplier to the CFG scale on a per-token basis).
|
||||
# For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept
|
||||
# captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active
|
||||
# interest, however small, in redness; what the user probably intends when they attach the number 0.01 to
|
||||
# "red" is to tell SD that it should almost completely *ignore* redness).
|
||||
# To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt
|
||||
# string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the
|
||||
# closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment.
|
||||
|
||||
# handle weights >=1
|
||||
tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights)
|
||||
base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
|
||||
|
||||
# this is our starting point
|
||||
embeddings = base_embedding.unsqueeze(0)
|
||||
per_embedding_weights = [1.0]
|
||||
|
||||
# now handle weights <1
|
||||
# Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped
|
||||
# with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting
|
||||
# embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words
|
||||
# removed.
|
||||
# eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding
|
||||
# for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it
|
||||
# such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
|
||||
for index, fragment_weight in enumerate(weights):
|
||||
if fragment_weight < 1:
|
||||
fragments_without_this = fragments[:index] + fragments[index+1:]
|
||||
weights_without_this = weights[:index] + weights[index+1:]
|
||||
tokens, per_token_weights = self.get_tokens_and_weights(fragments_without_this, weights_without_this)
|
||||
embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
|
||||
|
||||
embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
|
||||
# weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
|
||||
# if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
|
||||
# therefore:
|
||||
# fragment_weight = 1: we are at base_z => lerp weight 0
|
||||
# fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1
|
||||
# fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf
|
||||
# so let's use tan(), because:
|
||||
# tan is 0.0 at 0,
|
||||
# 1.0 at PI/4, and
|
||||
# inf at PI/2
|
||||
# -> tan((1-weight)*PI/2) should give us ideal lerp weights
|
||||
epsilon = 1e-9
|
||||
fragment_weight = max(epsilon, fragment_weight) # inf is bad
|
||||
embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
|
||||
# todo handle negative weight?
|
||||
|
||||
per_embedding_weights.append(embedding_lerp_weight)
|
||||
|
||||
lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
|
||||
|
||||
#print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
|
||||
|
||||
# append to batch
|
||||
batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat([batch_z, lerped_embeddings.unsqueeze(0)], dim=1)
|
||||
batch_tokens = tokens.unsqueeze(0) if batch_tokens is None else torch.cat([batch_tokens, tokens.unsqueeze(0)], dim=1)
|
||||
|
||||
# should have shape (B, 77, 768)
|
||||
#print(f"assembled all tokens into tensor of shape {batch_z.shape}")
|
||||
|
||||
if should_return_tokens:
|
||||
return batch_z, batch_tokens
|
||||
else:
|
||||
return batch_z
|
||||
|
||||
def get_tokens(self, fragments: list[str], include_start_and_end_markers: bool = True) -> list[list[int]]:
|
||||
tokens = self.tokenizer(
|
||||
fragments,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_overflowing_tokens=False,
|
||||
padding='do_not_pad',
|
||||
return_tensors=None, # just give me a list of ints
|
||||
)['input_ids']
|
||||
if include_start_and_end_markers:
|
||||
return tokens
|
||||
else:
|
||||
return [x[1:-1] for x in tokens]
|
||||
|
||||
|
||||
@classmethod
|
||||
def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
|
||||
per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
|
||||
if normalize:
|
||||
per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
|
||||
reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
|
||||
#reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
|
||||
return torch.sum(embeddings * reshaped_weights, dim=1)
|
||||
# lerped embeddings has shape (77, 768)
|
||||
|
||||
|
||||
def get_tokens_and_weights(self, fragments: list[str], weights: list[float]) -> (torch.Tensor, torch.Tensor):
|
||||
'''
|
||||
|
||||
:param fragments:
|
||||
:param weights: Per-fragment weights (CFG scaling). No need for these to be normalized. They will not be normalized here and that's fine.
|
||||
:return:
|
||||
'''
|
||||
# empty is meaningful
|
||||
if len(fragments) == 0 and len(weights) == 0:
|
||||
fragments = ['']
|
||||
weights = [1]
|
||||
item_encodings = self.tokenizer(
|
||||
fragments,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_overflowing_tokens=False,
|
||||
padding='do_not_pad',
|
||||
return_tensors=None, # just give me a list of ints
|
||||
)['input_ids']
|
||||
all_tokens = []
|
||||
per_token_weights = []
|
||||
#print("all fragments:", fragments, weights)
|
||||
for index, fragment in enumerate(item_encodings):
|
||||
weight = weights[index]
|
||||
#print("processing fragment", fragment, weight)
|
||||
fragment_tokens = item_encodings[index]
|
||||
#print("fragment", fragment, "processed to", fragment_tokens)
|
||||
# trim bos and eos markers before appending
|
||||
all_tokens.extend(fragment_tokens[1:-1])
|
||||
per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
|
||||
|
||||
if (len(all_tokens) + 2) > self.max_length:
|
||||
excess_token_count = (len(all_tokens) + 2) - self.max_length
|
||||
print(f"prompt is {excess_token_count} token(s) too long and has been truncated")
|
||||
all_tokens = all_tokens[:self.max_length - 2]
|
||||
|
||||
# pad out to a 77-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
|
||||
# (77 = self.max_length)
|
||||
pad_length = self.max_length - 1 - len(all_tokens)
|
||||
all_tokens.insert(0, self.tokenizer.bos_token_id)
|
||||
all_tokens.extend([self.tokenizer.eos_token_id] * pad_length)
|
||||
per_token_weights.insert(0, 1)
|
||||
per_token_weights.extend([1] * pad_length)
|
||||
|
||||
all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
|
||||
per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
|
||||
#print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
|
||||
return all_tokens_tensor, per_token_weights_tensor
|
||||
|
||||
def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
|
||||
'''
|
||||
Build a tensor representing the passed-in tokens, each of which has a weight.
|
||||
:param tokens: A tensor of shape (77) containing token ids (integers)
|
||||
:param per_token_weights: A tensor of shape (77) containing weights (floats)
|
||||
:param method: Whether to multiply the whole feature vector for each token or just its distance from an "empty" feature vector
|
||||
:param kwargs: passed on to self.transformer()
|
||||
:return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
|
||||
'''
|
||||
#print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
|
||||
z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs)
|
||||
batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
|
||||
|
||||
if weight_delta_from_empty:
|
||||
empty_tokens = self.tokenizer([''] * z.shape[0],
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
padding='max_length',
|
||||
return_tensors='pt'
|
||||
)['input_ids'].to(self.device)
|
||||
empty_z = self.transformer(input_ids=empty_tokens, **kwargs)
|
||||
z_delta_from_empty = z - empty_z
|
||||
weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
|
||||
|
||||
weighted_z_delta_from_empty = (weighted_z-empty_z)
|
||||
#print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
|
||||
|
||||
#print("using empty-delta method, first 5 rows:")
|
||||
#print(weighted_z[:5])
|
||||
|
||||
return weighted_z
|
||||
|
||||
else:
|
||||
original_mean = z.mean()
|
||||
z *= batch_weights_expanded
|
||||
after_weighting_mean = z.mean()
|
||||
# correct the mean. not sure if this is right but it's what the automatic1111 fork of SD does
|
||||
mean_correction_factor = original_mean/after_weighting_mean
|
||||
z *= mean_correction_factor
|
||||
return z
|
||||
|
||||
|
||||
class FrozenCLIPTextEmbedder(nn.Module):
|
||||
"""
|
||||
|
401
tests/test_prompt_parser.py
Normal file
401
tests/test_prompt_parser.py
Normal file
@ -0,0 +1,401 @@
|
||||
import unittest
|
||||
|
||||
import pyparsing
|
||||
|
||||
from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt, CrossAttentionControlSubstitute, \
|
||||
Fragment
|
||||
|
||||
|
||||
def parse_prompt(prompt_string):
|
||||
pp = PromptParser()
|
||||
#print(f"parsing '{prompt_string}'")
|
||||
parse_result = pp.parse_conjunction(prompt_string)
|
||||
#print(f"-> parsed '{prompt_string}' to {parse_result}")
|
||||
return parse_result
|
||||
|
||||
def make_basic_conjunction(strings: list[str]):
|
||||
fragments = [Fragment(x) for x in strings]
|
||||
return Conjunction([FlattenedPrompt(fragments)])
|
||||
|
||||
def make_weighted_conjunction(weighted_strings: list[tuple[str,float]]):
|
||||
fragments = [Fragment(x, w) for x,w in weighted_strings]
|
||||
return Conjunction([FlattenedPrompt(fragments)])
|
||||
|
||||
|
||||
class PromptParserTestCase(unittest.TestCase):
|
||||
|
||||
def test_empty(self):
|
||||
self.assertEqual(make_weighted_conjunction([('', 1)]), parse_prompt(''))
|
||||
|
||||
def test_basic(self):
|
||||
self.assertEqual(make_weighted_conjunction([('fire flames', 1)]), parse_prompt("fire (flames)"))
|
||||
self.assertEqual(make_weighted_conjunction([("fire flames", 1)]), parse_prompt("fire flames"))
|
||||
self.assertEqual(make_weighted_conjunction([("fire, flames", 1)]), parse_prompt("fire, flames"))
|
||||
self.assertEqual(make_weighted_conjunction([("fire, flames , fire", 1)]), parse_prompt("fire, flames , fire"))
|
||||
|
||||
def test_attention(self):
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 0.5)]), parse_prompt("(flames)0.5"))
|
||||
self.assertEqual(make_weighted_conjunction([('fire flames', 0.5)]), parse_prompt("(fire flames)0.5"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("(flames)+"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("flames+"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 1.1)]), parse_prompt("\"flames\"+"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("(flames)-"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("flames-"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', 0.9)]), parse_prompt("\"flames\"-"))
|
||||
self.assertEqual(make_weighted_conjunction([('fire', 1), ('flames', 0.5)]), parse_prompt("fire (flames)0.5"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', pow(1.1, 2))]), parse_prompt("(flames)++"))
|
||||
self.assertEqual(make_weighted_conjunction([('flames', pow(0.9, 2))]), parse_prompt("(flames)--"))
|
||||
self.assertEqual(make_weighted_conjunction([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))]), parse_prompt("(flowers)--- flames+++"))
|
||||
self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1)]),
|
||||
parse_prompt("(pretty flowers)+"))
|
||||
self.assertEqual(make_weighted_conjunction([('pretty flowers', 1.1), (', the flames are too hot', 1)]),
|
||||
parse_prompt("(pretty flowers)+, the flames are too hot"))
|
||||
|
||||
def test_no_parens_attention_runon(self):
|
||||
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(1.1, 2))]), parse_prompt("fire flames++"))
|
||||
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', pow(0.9, 2))]), parse_prompt("fire flames--"))
|
||||
self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)]), parse_prompt("flowers fire++ flames"))
|
||||
self.assertEqual(make_weighted_conjunction([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)]), parse_prompt("flowers fire-- flames"))
|
||||
|
||||
|
||||
def test_explicit_conjunction(self):
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and(1,1)'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()'))
|
||||
self.assertEqual(
|
||||
Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("(fire)2.0", "flames-").and()'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]),
|
||||
FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()'))
|
||||
|
||||
def test_conjunction_weights(self):
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[2.0,1.0]), parse_prompt('("fire", "flames").and(2,1)'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[1.0,2.0]), parse_prompt('("fire", "flames").and(1,2)'))
|
||||
|
||||
with self.assertRaises(PromptParser.ParsingException):
|
||||
parse_prompt('("fire", "flames").and(2)')
|
||||
parse_prompt('("fire", "flames").and(2,1,2)')
|
||||
|
||||
def test_complex_conjunction(self):
|
||||
|
||||
#print(parse_prompt("a person with a hat (riding a bicycle.swap(skateboard))++"))
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]),
|
||||
parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle)++\").and(0.5, 0.5)"))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]),
|
||||
FlattenedPrompt([("a person with a hat", 1.0),
|
||||
("riding a", 1.1*1.1),
|
||||
CrossAttentionControlSubstitute(
|
||||
[Fragment("bicycle", pow(1.1,2))],
|
||||
[Fragment("skateboard", pow(1.1,2))])
|
||||
])
|
||||
], weights=[0.5, 0.5]),
|
||||
parse_prompt("(\"mountain man\", \"a person with a hat (riding a bicycle.swap(skateboard))++\").and(0.5, 0.5)"))
|
||||
|
||||
def test_badly_formed(self):
|
||||
def make_untouched_prompt(prompt):
|
||||
return Conjunction([FlattenedPrompt([(prompt, 1.0)])])
|
||||
|
||||
def assert_if_prompt_string_not_untouched(prompt):
|
||||
self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt))
|
||||
|
||||
assert_if_prompt_string_not_untouched('a test prompt')
|
||||
# todo handle this
|
||||
#assert_if_prompt_string_not_untouched('a badly formed +test prompt')
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('a badly (formed test prompt')
|
||||
#with self.assertRaises(pyparsing.ParseException):
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('a badly (formed +test prompt')
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('a badly (formed +test )prompt')
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('a badly (formed +test )prompt')
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('(((a badly (formed +test )prompt')
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('(a (ba)dly (f)ormed +test prompt')
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('(a (ba)dly (f)ormed +test +prompt')
|
||||
with self.assertRaises(pyparsing.ParseException):
|
||||
parse_prompt('("((a badly (formed +test ").blend(1.0)')
|
||||
|
||||
|
||||
def test_blend(self):
|
||||
self.assertEqual(Conjunction(
|
||||
[Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]),
|
||||
parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)")
|
||||
)
|
||||
self.assertEqual(Conjunction([Blend(
|
||||
[FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('hi', 1.0)])],
|
||||
[0.7, 0.3, 1.0])]),
|
||||
parse_prompt("(\"fire\", \"fire flames\", \"hi\").blend(0.7, 0.3, 1.0)")
|
||||
)
|
||||
self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]),
|
||||
FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]),
|
||||
FlattenedPrompt([('hi', 1.0)])],
|
||||
weights=[0.7, 0.3, 1.0])]),
|
||||
parse_prompt("(\"fire\", \"fire flames (hot)++\", \"hi\").blend(0.7, 0.3, 1.0)")
|
||||
)
|
||||
# blend a single entry is not a failure
|
||||
self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]),
|
||||
parse_prompt("(\"fire\").blend(0.7)")
|
||||
)
|
||||
# blend with empty
|
||||
self.assertEqual(
|
||||
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
|
||||
parse_prompt("(\"fire\", \"\").blend(0.7, 1)")
|
||||
)
|
||||
self.assertEqual(
|
||||
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
|
||||
parse_prompt("(\"fire\", \" \").blend(0.7, 1)")
|
||||
)
|
||||
self.assertEqual(
|
||||
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
|
||||
parse_prompt("(\"fire\", \" \").blend(0.7, 1)")
|
||||
)
|
||||
self.assertEqual(
|
||||
Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([(',', 1.0)])], [0.7, 1.0])]),
|
||||
parse_prompt("(\"fire\", \" , \").blend(0.7, 1)")
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
Conjunction([Blend([FlattenedPrompt([('mountain, man, hairy', 1)]),
|
||||
FlattenedPrompt([('face, teeth,', 1), ('eyes', 0.9*0.9)])], weights=[1.0,-1.0])]),
|
||||
parse_prompt('("mountain, man, hairy", "face, teeth, eyes--").blend(1,-1)')
|
||||
)
|
||||
|
||||
|
||||
def test_nested(self):
|
||||
self.assertEqual(make_weighted_conjunction([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)]),
|
||||
parse_prompt('fire (flames (trees)1.5)2.0'))
|
||||
self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]),
|
||||
FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])],
|
||||
weights=[1.0, 1.0])]),
|
||||
parse_prompt('("fire (flames)++", "mountain (man)2").blend(1,1)'))
|
||||
|
||||
def test_cross_attention_control(self):
|
||||
|
||||
self.assertEqual(Conjunction([
|
||||
FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
|
||||
Fragment('eating a hotdog', 1)])]), parse_prompt("a \"cat\".swap(dog) eating a hotdog"))
|
||||
|
||||
self.assertEqual(Conjunction([
|
||||
FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)]),
|
||||
Fragment('eating a hotdog', 1)])]), parse_prompt("a cat.swap(dog) eating a hotdog"))
|
||||
|
||||
|
||||
fire_flames_to_trees = Conjunction([FlattenedPrompt([('fire', 1.0), \
|
||||
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees', 1)])])])
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap(trees)'))
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap(trees)'))
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap(trees)'))
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire "flames".swap("trees")'))
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire (flames).swap("trees")'))
|
||||
self.assertEqual(fire_flames_to_trees, parse_prompt('fire ("flames").swap("trees")'))
|
||||
|
||||
fire_flames_to_trees_and_houses = Conjunction([FlattenedPrompt([('fire', 1.0), \
|
||||
CrossAttentionControlSubstitute([Fragment('flames', 1)], [Fragment('trees and houses', 1)])])])
|
||||
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire ("flames").swap("trees and houses")'))
|
||||
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire (flames).swap("trees and houses")'))
|
||||
self.assertEqual(fire_flames_to_trees_and_houses, parse_prompt('fire "flames".swap("trees and houses")'))
|
||||
|
||||
trees_and_houses_to_flames = Conjunction([FlattenedPrompt([('fire', 1.0), \
|
||||
CrossAttentionControlSubstitute([Fragment('trees and houses', 1)], [Fragment('flames',1)])])])
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap("flames")'))
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap("flames")'))
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap("flames")'))
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire ("trees and houses").swap(flames)'))
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire (trees and houses).swap(flames)'))
|
||||
self.assertEqual(trees_and_houses_to_flames, parse_prompt('fire "trees and houses".swap(flames)'))
|
||||
|
||||
flames_to_trees_fire = Conjunction([FlattenedPrompt([
|
||||
CrossAttentionControlSubstitute([Fragment('flames',1)], [Fragment('trees',1)]),
|
||||
(', fire', 1.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap("trees"), fire'))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap("trees"), fire'))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap("trees"), fire'))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"flames".swap(trees), fire'))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('(flames).swap(trees), fire '))
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('("flames").swap(trees), fire '))
|
||||
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
|
||||
parse_prompt('a forest landscape "".swap("in winter")'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('',1)], [Fragment('in winter',1)])])]),
|
||||
parse_prompt('a forest landscape " ".swap("in winter")'))
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
|
||||
parse_prompt('a forest landscape "in winter".swap("")'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
|
||||
parse_prompt('a forest landscape "in winter".swap()'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([Fragment('a forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('in winter',1)], [Fragment('',1)])])]),
|
||||
parse_prompt('a forest landscape "in winter".swap(" ")'))
|
||||
|
||||
def test_cross_attention_control_with_attention(self):
|
||||
flames_to_trees_fire = Conjunction([FlattenedPrompt([
|
||||
CrossAttentionControlSubstitute([Fragment('flames',0.5)], [Fragment('trees',0.7)]),
|
||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"(flames)0.5".swap("(trees)0.7"), (fire)2.0'))
|
||||
flames_to_trees_fire = Conjunction([FlattenedPrompt([
|
||||
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7)]),
|
||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7"), (fire)2.0'))
|
||||
flames_to_trees_fire = Conjunction([FlattenedPrompt([
|
||||
CrossAttentionControlSubstitute([Fragment('fire',0.5), Fragment('flames',0.25)], [Fragment('trees',0.7), Fragment('houses', 1)]),
|
||||
Fragment(',', 1), Fragment('fire', 2.0)])])
|
||||
self.assertEqual(flames_to_trees_fire, parse_prompt('"(fire (flames)0.5)0.5".swap("(trees)0.7 houses"), (fire)2.0'))
|
||||
|
||||
def test_cross_attention_control_options(self):
|
||||
self.assertEqual(Conjunction([
|
||||
FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start':0.1}),
|
||||
Fragment('eating a hotdog', 1)])]),
|
||||
parse_prompt("a \"cat\".swap(dog, s_start=0.1) eating a hotdog"))
|
||||
self.assertEqual(Conjunction([
|
||||
FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'t_start':0.1}),
|
||||
Fragment('eating a hotdog', 1)])]),
|
||||
parse_prompt("a \"cat\".swap(dog, t_start=0.1) eating a hotdog"))
|
||||
self.assertEqual(Conjunction([
|
||||
FlattenedPrompt([Fragment('a', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('cat', 1)], [Fragment('dog', 1)], options={'s_start': 20.0, 't_start':0.1}),
|
||||
Fragment('eating a hotdog', 1)])]),
|
||||
parse_prompt("a \"cat\".swap(dog, t_start=0.1, s_start=20) eating a hotdog"))
|
||||
|
||||
self.assertEqual(
|
||||
Conjunction([
|
||||
FlattenedPrompt([Fragment('a fantasy forest landscape', 1),
|
||||
CrossAttentionControlSubstitute([Fragment('', 1)], [Fragment('with a river', 1)],
|
||||
options={'s_start': 0.8, 't_start': 0.8})])]),
|
||||
parse_prompt("a fantasy forest landscape \"\".swap(with a river, s_start=0.8, t_start=0.8)"))
|
||||
|
||||
|
||||
def test_escaping(self):
|
||||
|
||||
# make sure ", ( and ) can be escaped
|
||||
|
||||
self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain \(man\)'))
|
||||
self.assertEqual(make_basic_conjunction(['mountain (man )']),parse_prompt('mountain (\(man)\)'))
|
||||
self.assertEqual(make_basic_conjunction(['mountain (man)']),parse_prompt('mountain (\(man\))'))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('mountain (\(man\))+'))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('(man)', 1.1)]), parse_prompt('"mountain" (\(man\))+'))
|
||||
self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('(man)', 1.1)]), parse_prompt('\\"mountain\\" (\(man\))+'))
|
||||
# same weights for each are combined into one
|
||||
self.assertEqual(make_weighted_conjunction([('"mountain" (man)', 1.1)]), parse_prompt('(\\"mountain\\")+ (\(man\))+'))
|
||||
self.assertEqual(make_weighted_conjunction([('"mountain"', 1.1), ('(man)', 0.9)]), parse_prompt('(\\"mountain\\")+ (\(man\))-'))
|
||||
|
||||
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('mountain (\(man\))1.1'))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain', 1), ('\(man\)', 1.1)]),parse_prompt('"mountain" (\(man\))1.1'))
|
||||
self.assertEqual(make_weighted_conjunction([('"mountain"', 1), ('\(man\)', 1.1)]),parse_prompt('\\"mountain\\" (\(man\))1.1'))
|
||||
# same weights for each are combined into one
|
||||
self.assertEqual(make_weighted_conjunction([('\\"mountain\\" \(man\)', 1.1)]),parse_prompt('(\\"mountain\\")+ (\(man\))1.1'))
|
||||
self.assertEqual(make_weighted_conjunction([('\\"mountain\\"', 1.1), ('\(man\)', 0.9)]),parse_prompt('(\\"mountain\\")1.1 (\(man\))0.9'))
|
||||
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy (mountain (\(man\))+)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('\(man\)', 1.1*1.1), ('mountain', 1.1)]),parse_prompt('hairy ((\(man\))1.1 "mountain")+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain', 1.1), ('\(man\)', 1.1*1.1)]),parse_prompt('hairy ("mountain" (\(man\))1.1 )+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man', 1.1)]),parse_prompt('hairy ("mountain, man")+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*1.1)]), parse_prompt('hairy ("mountain, man" with a beard+)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, man" with a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\"man\\"" with a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, m\"an\" with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, m\\"an\\"" with a (beard)2.0)+'))
|
||||
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \(with a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\(ith a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\( a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" \)with a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mountain, \\\"man\" with\) a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy', 1), ('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy ("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hai(ry', 1), ('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hai\(ry ("mountain, \\\"man\" w\)ith a (beard)2.0)+'))
|
||||
self.assertEqual(make_weighted_conjunction([('hairy((', 1), ('mountain, \"man with a', 1.1), ('beard', 1.1*2.0)]), parse_prompt('hairy\(\( ("mountain, \\\"man\" with a (beard)2.0)+'))
|
||||
|
||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man (with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \(with a (beard)2.0)+ hairy'))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man w(ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\(ith a (beard)2.0)+hairy'))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man with( a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" with\( a (beard)2.0)+ hairy'))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man )with a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" \)with a (beard)2.0)+ hairy'))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hairy'))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man with) a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt(' ("mountain, \\\"man\" with\) a (beard)2.0)+ hairy'))
|
||||
self.assertEqual(make_weighted_conjunction([('mou)ntain, \"man (wit(h a', 1.1), ('beard', 1.1*2.0), ('hairy', 1)]), parse_prompt('("mou\)ntain, \\\"man\" \(wit\(h a (beard)2.0)+ hairy'))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man w)ith a', 1.1), ('beard', 1.1*2.0), ('hai(ry', 1)]), parse_prompt('("mountain, \\\"man\" w\)ith a (beard)2.0)+ hai\(ry '))
|
||||
self.assertEqual(make_weighted_conjunction([('mountain, \"man with a', 1.1), ('beard', 1.1*2.0), ('hairy((', 1)]), parse_prompt('("mountain, \\\"man\" with a (beard)2.0)+ hairy\(\( '))
|
||||
|
||||
def test_cross_attention_escaping(self):
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
|
||||
parse_prompt('mountain (man).swap(monkey)'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]),
|
||||
parse_prompt('mountain (man).swap(m\(onkey)'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]),
|
||||
parse_prompt('mountain (m\(an).swap(m\(onkey)'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]),
|
||||
parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))
|
||||
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
|
||||
parse_prompt('mountain ("man").swap(monkey)'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('monkey', 1)])])]),
|
||||
parse_prompt('mountain ("man").swap("monkey")'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('"man', 1)], [Fragment('monkey', 1)])])]),
|
||||
parse_prompt('mountain (\\"man).swap("monkey")'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('man', 1)], [Fragment('m(onkey', 1)])])]),
|
||||
parse_prompt('mountain (man).swap(m\(onkey)'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('m(an', 1)], [Fragment('m(onkey', 1)])])]),
|
||||
parse_prompt('mountain (m\(an).swap(m\(onkey)'))
|
||||
self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]),
|
||||
parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))
|
||||
|
||||
def test_legacy_blend(self):
|
||||
pp = PromptParser()
|
||||
|
||||
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
|
||||
FlattenedPrompt([('man mountain', 1)])],
|
||||
weights=[0.5,0.5]),
|
||||
pp.parse_legacy_blend('mountain man:1 man mountain:1'))
|
||||
|
||||
self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
|
||||
FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
|
||||
weights=[0.5,0.5]),
|
||||
pp.parse_legacy_blend('mountain+ man:1 man mountain-:1'))
|
||||
|
||||
self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
|
||||
FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
|
||||
weights=[0.5,0.5]),
|
||||
pp.parse_legacy_blend('mountain+ man:1 man mountain-'))
|
||||
|
||||
self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
|
||||
FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
|
||||
weights=[0.5,0.5]),
|
||||
pp.parse_legacy_blend('mountain+ man: man mountain-:'))
|
||||
|
||||
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
|
||||
FlattenedPrompt([('man mountain', 1)])],
|
||||
weights=[0.75,0.25]),
|
||||
pp.parse_legacy_blend('mountain man:3 man mountain:1'))
|
||||
|
||||
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
|
||||
FlattenedPrompt([('man mountain', 1)])],
|
||||
weights=[1.0,0.0]),
|
||||
pp.parse_legacy_blend('mountain man:3 man mountain:0'))
|
||||
|
||||
self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
|
||||
FlattenedPrompt([('man mountain', 1)])],
|
||||
weights=[0.8,0.2]),
|
||||
pp.parse_legacy_blend('"mountain man":4 man mountain'))
|
||||
|
||||
|
||||
def test_single(self):
|
||||
# todo handle this
|
||||
#self.assertEqual(make_basic_conjunction(['a badly formed +test prompt']),
|
||||
# parse_prompt('a badly formed +test prompt'))
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in New Issue
Block a user