Squashed commit of the following:

commit 9bb0b5d0036c4dffbb72ce11e097fae4ab63defd
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Sat Oct 15 23:43:41 2022 +0200

    undo local_files_only stuff

commit eed93f5d30c34cfccaf7497618ae9af17a5ecfbb
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Sat Oct 15 23:40:37 2022 +0200

    Revert "Merge branch 'development-invoke' into fix-prompts"

    This reverts commit 7c40892a9f184f7e216f14d14feb0411c5a90e24, reversing
    changes made to e3f2dd62b0548ca6988818ef058093a4f5b022f2.

commit f06d6024e345c69e6d5a91ab5423925a68ee95a7
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Thu Oct 13 23:30:16 2022 +0200

    more efficiently handle multiple conditioning

commit 5efdfcbcd980ce6202ab74e7f90e7415ce7260da
Merge: b9c0dc5 ac08bb6
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Thu Oct 13 14:51:01 2022 +0200

    Merge branch 'optional-disable-karras-schedule' into fix-prompts

commit ac08bb6fd25e19a9d35cf6c199e66500fb604af1
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Thu Oct 13 14:50:43 2022 +0200

    append '*use_model_sigmas*' to prompt string to use model sigmas

commit 70d8c05a3ff329409f76204f4af94e55d468ab8b
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Thu Oct 13 12:12:17 2022 +0200

    make karras scheduling switchable

    commit d60df54f69968e2fb22809c55e23b3c02f37ad63 replaced the model's
    own scheduling with karras scheduling. this has changed image generation
    (seems worse now?)

    this commit wraps the change in a bool.

commit b9c0dc5f1a658a0e6c3936000e9ae559e1c7a1db
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Wed Oct 12 20:16:00 2022 +0200

    add test of more complex conjunction

commit 9ac0c15cc0d7b5f6df3289d3ad474260972a17be
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Wed Oct 12 17:18:25 2022 +0200

    improve comments

commit ad33bce60590b87b2a93e90f16dc9d3e935d04a5
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Wed Oct 12 17:04:46 2022 +0200

    put back thresholding stuff

commit 4852c698a325049834ba0d4b358f07210bc7171a
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Wed Oct 12 14:25:02 2022 +0200

    notes on improving conjunction efficiency

commit a53bb1e5b68025d09642b935ae6a9a015cfaf2d6
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Wed Oct 12 14:14:33 2022 +0200

    optional weights support for Conjunction

commit fec79ab15e4f0c84dd61cb1b45a5e6a72ae4aaeb
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Wed Oct 12 12:07:27 2022 +0200

    fix blend error and log parsing output

commit 1f751c2a039f9c97af57b18e0f019512631d5a25
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Wed Oct 12 10:33:33 2022 +0200

    fix broken euler sampler

commit 02f8148d17efe4b6bde8d29b827092a0626363ee
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Wed Oct 12 10:24:20 2022 +0200

    cleanup prompt parser

commit 8028d49ae6c16c0d6ec9c9de9c12d56c32201421
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Wed Oct 12 10:14:18 2022 +0200

    explicit conjunction, improve flattening logic

commit 8a1710892185f07eb77483f7edae0fc4d6bbb250
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Tue Oct 11 22:59:30 2022 +0200

    adapt multi-conditioning to also work with ddim

commit 53802a839850d0d1ff017c6bafe457c4bed750b0
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Tue Oct 11 22:31:42 2022 +0200

    unconditioning is also fancy-prompt-syntaxable

commit 7c40892a9f184f7e216f14d14feb0411c5a90e24
Merge: e3f2dd6 dbe0da4
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Tue Oct 11 21:39:54 2022 +0200

    Merge branch 'development-invoke' into fix-prompts

commit e3f2dd62b0548ca6988818ef058093a4f5b022f2
Merge: eef0e48 06f542e
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Tue Oct 11 21:38:09 2022 +0200

    Merge remote-tracking branch 'upstream/development' into fix-prompts

commit eef0e484c2eaa1bd4e0e0b1d3f8d7bba38478144
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Tue Oct 11 21:26:25 2022 +0200

    fix run-on paren-less attention, add some comments

commit fd29afdf0e9f5e0cdc60239e22480c36ca0aaeca
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Tue Oct 11 21:03:02 2022 +0200

    python 3.9 compatibility

commit 26f7646eef7f39bc8f7ce805e747df0f723464da
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Tue Oct 11 20:58:42 2022 +0200

    first pass connecting PromptParser to conditioning

commit ae53dff3796d7b9a5e7ed30fa1edb0374af6cd8d
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Tue Oct 11 20:51:15 2022 +0200

    update frontend dist

commit 9be4a59a2d76f49e635474b5984bfca826a5dab4
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Tue Oct 11 19:01:39 2022 +0200

    fix issues with correctness checking FlattenedPrompt

commit 3be212323eab68e72a363a654124edd9809e4cf0
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Tue Oct 11 18:43:16 2022 +0200

    parsing nested seems to work pretty ok

commit acd73eb08cf67c27cac8a22934754321256f56a9
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Tue Oct 11 18:26:17 2022 +0200

    wip introducing FlattenedPrompt class

commit 71698d5c7c2ac855b690d8ef67e8830148c59eda
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Tue Oct 11 15:59:42 2022 +0200

    recursive attention weighting seems to actually work

commit a4e1ec6b20deb7cc0cd12737bdbd266e56144709
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Tue Oct 11 15:06:24 2022 +0200

    now apparently almost supported nested attention

commit da76fd1ddf22a3888cdc08fd4fed38d8b178e524
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Tue Oct 11 13:23:37 2022 +0200

    wip prompt parsing

commit dbe0da4572c2ac22f26a7afd722349a5680a9e47
Author: Kyle Schouviller <kyle0654@hotmail.com>
Date:   Mon Oct 10 22:32:35 2022 -0700

    Adding node-based invocation apps

commit 8f2a2ffc083366de74d7dae471b50b6f98a7c5f8
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Mon Oct 10 19:03:18 2022 +0200

    fix merge issues

commit 73118dee2a8f4891700756e014caf1c9ca629267
Merge: fd00844 12413b0
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Mon Oct 10 12:42:48 2022 +0200

    Merge remote-tracking branch 'upstream/development' into fix-prompts

commit fd0084413541013c2cf71e006af0392719bef53d
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Mon Oct 10 12:39:38 2022 +0200

    wip prompt parsing

commit 0be9363db9307859d2b65cffc6af01f57d7873a4
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Mon Oct 10 03:20:06 2022 +0200

    better +/- attention parsing

commit 5383f691874a58ab01cda1e4fac6cf330146526a
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Mon Oct 10 02:27:47 2022 +0200

    prompt parser seems to work

commit 591d098a33ce35462428d8c169501d8ed73615ab
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Sun Oct 9 20:25:37 2022 +0200

    supports weighting unconditioning, cross-attention with |

commit 7a7220563aa05a2980235b5b908362f66b728309
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Sun Oct 9 18:15:56 2022 +0200

    i think cross attention might be working?

commit 951ed391e7126bff228c18b2db304ad28d59644a
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Sun Oct 9 16:04:54 2022 +0200

    weighted CFG denoiser working with a single item

commit ee532a0c2827368c9e45a6a5f3975666402873da
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Sun Oct 9 06:33:40 2022 +0200

    wip probably doesn't work or compile

commit 14654bcbd207b9ca28a6cbd37dbd967d699b062d
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Fri Oct 7 18:11:48 2022 +0200

    use tan() to calculate embedding weight for <1 attentions

commit 1a8e76b31aa5abf5150419ebf3b29d4658d07f2b
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Fri Oct 7 16:14:54 2022 +0200

    fix bad math.max reference

commit f697ff896875876ccaa1e5527405bdaa7ed27cde
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Fri Oct 7 15:55:57 2022 +0200

    respect http[s]x protocol when making socket.io middleware

commit 41d3dd4eeae8d4efb05dfb44fc6d8aac5dc468ab
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Fri Oct 7 13:29:54 2022 +0200

    fractional weighting works, by blending with prompts excluding the word

commit 087fb6dfb3e8f5e84de8c911f75faa3e3fa3553c
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Fri Oct 7 10:52:03 2022 +0200

    wip doing weights <1 by averaging with conditioning absent the lower-weighted fragment

commit 3c49e3f3ec7c18dc60f3e18ed2f7f0d97aad3a47
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Fri Oct 7 10:36:15 2022 +0200

    notate CFGDenoiser, perhaps

commit d2bcf1bb522026ebf209ad0103f6b370383e5070
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Thu Oct 6 05:04:47 2022 +0200

    hack blending syntax to test attention weighting more extensively

commit 94904ef2cf917f74ec23ef7a570e12ff8255b048
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Thu Oct 6 04:56:37 2022 +0200

    conditioning works, apparently

commit 7c6663ddd70f665fd1308b6dd74f92ca393a8df5
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Thu Oct 6 02:20:24 2022 +0200

    attention weighting, definitely works in positive direction

commit 5856d453a9b020bc1a28ff643ae1f58c12c9be73
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Tue Oct 4 19:02:14 2022 +0200

    wip bubbling weights down

commit a2ed14fd9b7d3cb36b6c5348018b364c76d1e892
Author: Damian at mba <damian@frey.NOSPAMco.nz>
Date:   Tue Oct 4 17:35:39 2022 +0200

    bring in changes from PC
commit c3b992db96 (parent 92d4dfaabf)
Author: Damian at mba
Date:   2022-10-15 23:44:54 +02:00

11 changed files with 823 additions and 43 deletions

@@ -527,7 +527,7 @@ def parameters_to_generated_image_metadata(parameters):
     rfc_dict["sampler"] = parameters["sampler_name"]

     # display weighted subprompts (liable to change)
-    subprompts = split_weighted_subprompts(parameters["prompt"])
+    subprompts = split_weighted_subprompts(parameters["prompt"], skip_normalize=True)
     subprompts = [{"prompt": x[0], "weight": x[1]} for x in subprompts]
     rfc_dict["prompt"] = subprompts

@@ -76,4 +76,4 @@ model:
       target: torch.nn.Identity
     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+      target: ldm.modules.encoders.modules.WeightedFrozenCLIPEmbedder

@@ -438,7 +438,7 @@ class Generate:
                 sampler=self.sampler,
                 steps=steps,
                 cfg_scale=cfg_scale,
-                conditioning=(uc, c),
+                conditioning=(uc, c), # here change to arrays
                 ddim_eta=ddim_eta,
                 image_callback=image_callback,   # called after the final image is generated
                 step_callback=step_callback,     # called after each intermediate image is generated
@@ -477,6 +477,10 @@ class Generate:
                 print('**Interrupted** Partial results will be returned.')
             else:
                 raise KeyboardInterrupt
+        # brute-force fallback
+        except Exception as e:
+            print(traceback.format_exc(), file=sys.stderr)
+            print('>> Could not generate image.')

         toc = time.time()
         print('>> Usage stats:')

@@ -12,42 +12,76 @@ log_tokenization()     print out colour-coded tokens and warn if truncated

 import re
 import torch

-def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
+from .prompt_parser import PromptParser, Fragment, Attention, Blend, Conjunction, FlattenedPrompt
+from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
+
+def get_uc_and_c(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):
     # Extract Unconditioned Words From Prompt
     unconditioned_words = ''
     unconditional_regex = r'\[(.*?)\]'
-    unconditionals = re.findall(unconditional_regex, prompt)
+    unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)

     if len(unconditionals) > 0:
         unconditioned_words = ' '.join(unconditionals)

         # Remove Unconditioned Words From Prompt
         unconditional_regex_compile = re.compile(unconditional_regex)
-        clean_prompt = unconditional_regex_compile.sub(' ', prompt)
-        prompt = re.sub(' +', ' ', clean_prompt)
+        clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned)
+        prompt_string_cleaned = re.sub(' +', ' ', clean_prompt)
+    else:
+        prompt_string_cleaned = prompt_string_uncleaned

-    uc = model.get_learned_conditioning([unconditioned_words])
+    pp = PromptParser()

-    # get weighted sub-prompts
-    weighted_subprompts = split_weighted_subprompts(
-        prompt, skip_normalize
-    )
+    def build_conditioning_list(prompt_string: str):
+        parsed_conjunction: Conjunction = pp.parse(prompt_string)
+        print(f"parsed '{prompt_string}' to {parsed_conjunction}")
+        assert (type(parsed_conjunction) is Conjunction)
+
+        conditioning_list = []
+        def make_embeddings_for_flattened_prompt(flattened_prompt: FlattenedPrompt):
+            if type(flattened_prompt) is not FlattenedPrompt:
+                raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
+            fragments = [x[0] for x in flattened_prompt.children]
+            attention_weights = [x[1] for x in flattened_prompt.children]
+            print(fragments, attention_weights)
+            return model.get_learned_conditioning([fragments], attention_weights=[attention_weights])
+
+        for part, weight in zip(parsed_conjunction.prompts, parsed_conjunction.weights):
+            if type(part) is Blend:
+                blend: Blend = part
+                embeddings_to_blend = None
+                for flattened_prompt in blend.prompts:
+                    this_embedding = make_embeddings_for_flattened_prompt(flattened_prompt)
+                    embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat((embeddings_to_blend, this_embedding))
+                blended_embeddings = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0), blend.weights, normalize=blend.normalize_weights)
+                conditioning_list.append((blended_embeddings, weight))
+            else:
+                flattened_prompt: FlattenedPrompt = part
+                embeddings = make_embeddings_for_flattened_prompt(flattened_prompt)
+                conditioning_list.append((embeddings, weight))
+
+        return conditioning_list
+
+    positive_conditioning_list = build_conditioning_list(prompt_string_cleaned)
+    negative_conditioning_list = build_conditioning_list(unconditioned_words)
+
+    if len(negative_conditioning_list) == 0:
+        negative_conditioning = model.get_learned_conditioning([['']], attention_weights=[[1]])
+    else:
+        if len(negative_conditioning_list) > 1:
+            print("cannot do conjunctions on unconditioning for now")
+        negative_conditioning = negative_conditioning_list[0][0]
+
+    #positive_conditioning_list.append((get_blend_prompts_and_weights(prompt), this_weight))
+    #print("got empty_conditionining with shape", empty_conditioning.shape, "c[0][0] with shape", positive_conditioning[0][0].shape)

-    if len(weighted_subprompts) > 1:
-        # i dont know if this is correct.. but it works
-        c = torch.zeros_like(uc)
-        # normalize each "sub prompt" and add it
-        for subprompt, weight in weighted_subprompts:
-            log_tokenization(subprompt, model, log_tokens, weight)
-            c = torch.add(
-                c,
-                model.get_learned_conditioning([subprompt]),
-                alpha=weight,
-            )
-    else:   # just standard 1 prompt
-        log_tokenization(prompt, model, log_tokens, 1)
-        c = model.get_learned_conditioning([prompt])
-        uc = model.get_learned_conditioning([unconditioned_words])
+    # "unconditioned" means "the conditioning tensor is empty"
+    uc = negative_conditioning
+    c = positive_conditioning_list

     return (uc, c)

 def split_weighted_subprompts(text, skip_normalize=False)->list:
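After this change the shape of the conditioning data changes: uc stays a single embedding tensor, while c becomes a list of (embedding, weight) tuples, one per conjunction part. A minimal sketch of what a caller now receives (names and values illustrative; the 768-dim hidden size assumes the standard CLIP ViT-L text encoder, and model stands in for a loaded LatentDiffusion instance):

    # for the prompt '("fire", "flames").and(2, 1)':
    uc, c = get_uc_and_c('("fire", "flames").and(2, 1)', model)
    # uc is a tensor of shape (1, 77, 768)
    for embedding, weight in c:
        print(embedding.shape, weight)   # torch.Size([1, 77, 768]) 2.0, then 1.0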

ldm/invoke/prompt_parser.py (new file, 331 lines)

@@ -0,0 +1,331 @@
import pyparsing
import pyparsing as pp
from pyparsing import original_text_for


class Prompt():
    def __init__(self, parts: list):
        for c in parts:
            allowed_types = [Fragment, Attention, CFGScale]
            if type(c) not in allowed_types:
                raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {allowed_types} are allowed")
        self.children = parts

    def __repr__(self):
        return f"Prompt:{self.children}"

    def __eq__(self, other):
        return type(other) is Prompt and other.children == self.children


class FlattenedPrompt():
    def __init__(self, parts: list):
        # verify type correctness
        for c in parts:
            if type(c) is not tuple:
                raise PromptParser.ParsingException(
                    f"FlattenedPrompt cannot contain {type(c)}, only ('text', weight) tuples are allowed")
            text = c[0]
            weight = c[1]
            if type(text) is not str:
                raise PromptParser.ParsingException(
                    f"FlattenedPrompt cannot contain {type(c)}, only ('text', weight) tuples are allowed")
            if type(weight) is not float and type(weight) is not int:
                raise PromptParser.ParsingException(
                    f"FlattenedPrompt cannot contain {type(c)}, only ('text', weight) tuples are allowed")
        # all looks good
        self.children = parts

    def __repr__(self):
        return f"FlattenedPrompt:{self.children}"

    def __eq__(self, other):
        return type(other) is FlattenedPrompt and other.children == self.children


class Attention():
    def __init__(self, weight: float, children: list):
        self.weight = weight
        self.children = children
        #print(f"A: requested attention '{children}' to {weight}")

    def __repr__(self):
        return f"Attention:'{self.children}' @ {self.weight}"

    def __eq__(self, other):
        return type(other) is Attention and other.weight == self.weight and other.children == self.children


class CFGScale():
    def __init__(self, scale_factor: float, fragment: str):
        self.fragment = fragment
        self.scale_factor = scale_factor
        #print(f"S: requested CFGScale '{fragment}' x {scale_factor}")

    def __repr__(self):
        return f"CFGScale:'{self.fragment}' x {self.scale_factor}"

    def __eq__(self, other):
        return type(other) is CFGScale and other.scale_factor == self.scale_factor and other.fragment == self.fragment


class Fragment():
    def __init__(self, text: str):
        assert(type(text) is str)
        self.text = text

    def __repr__(self):
        return "Fragment:'"+self.text+"'"

    def __eq__(self, other):
        return type(other) is Fragment and other.text == self.text


class Conjunction():
    def __init__(self, prompts: list, weights: list = None):
        # force everything to be a Prompt
        #print("making conjunction with", prompts)
        self.prompts = [x if (type(x) is Prompt or type(x) is Blend or type(x) is FlattenedPrompt)
                        else Prompt(x) for x in prompts]
        self.weights = [1.0]*len(self.prompts) if weights is None else list(weights)
        if len(self.weights) != len(self.prompts):
            raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
        self.type = 'AND'

    def __repr__(self):
        return f"Conjunction:{self.prompts} | weights {self.weights}"

    def __eq__(self, other):
        return type(other) is Conjunction \
            and other.prompts == self.prompts \
            and other.weights == self.weights


class Blend():
    def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
        #print("making Blend with prompts", prompts, "and weights", weights)
        if len(prompts) != len(weights):
            raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
        for c in prompts:
            if type(c) is not Prompt and type(c) is not FlattenedPrompt:
                raise PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts")
        # upcast all lists to Prompt objects
        self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
                        else Prompt(x) for x in prompts]
        self.weights = weights
        self.normalize_weights = normalize_weights

    def __repr__(self):
        return f"Blend:{self.prompts} | weights {self.weights}"

    def __eq__(self, other):
        return other.__repr__() == self.__repr__()


class PromptParser():

    class ParsingException(Exception):
        pass

    def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
        self.attention_plus_base = attention_plus_base
        self.attention_minus_base = attention_minus_base
        self.root = self.build_parser_logic()

    def parse(self, prompt: str) -> Conjunction:
        '''
        :param prompt: The prompt string to parse
        :return: a Conjunction representing the parsed results
        '''
        #print(f"!!parsing '{prompt}'")

        if len(prompt.strip()) == 0:
            return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0])

        root = self.root.parse_string(prompt)
        #print(f"'{prompt}' parsed to root", root)
        #fused = fuse_fragments(parts)
        #print("fused to", fused)

        return self.flatten(root[0])

    def flatten(self, root: Conjunction):

        def fuse_fragments(items):
            # print("fusing fragments in ", items)
            result = []
            for x in items:
                last_weight = result[-1][1] if len(result) > 0 else None
                this_text = x[0]
                this_weight = x[1]
                if last_weight is not None and last_weight == this_weight:
                    last_text = result[-1][0]
                    result[-1] = (last_text + ' ' + this_text, last_weight)
                else:
                    result.append(x)
            return result

        def flatten_internal(node, weight_scale, results, prefix):
            #print(prefix + "flattening", node, "...")
            if type(node) is pp.ParseResults:
                for x in node:
                    results = flatten_internal(x, weight_scale, results, prefix+'pr')
                #print(prefix, " ParseResults expanded, results is now", results)
            elif type(node) is Fragment:
                results.append((node.text, float(weight_scale)))
            elif type(node) is Attention:
                #if node.weight < 1:
                # todo: inject a blend when flattening attention with weight <1
                for c in node.children:
                    results = flatten_internal(c, weight_scale*node.weight, results, prefix+' ')
            elif type(node) is Blend:
                flattened_subprompts = []
                #print(" flattening blend with prompts", node.prompts, "weights", node.weights)
                for prompt in node.prompts:
                    # prompt is a list
                    flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ')
                results += [Blend(prompts=flattened_subprompts, weights=node.weights)]
            elif type(node) is Prompt:
                #print(prefix + "about to flatten Prompt with children", node.children)
                flattened_prompt = []
                for child in node.children:
                    flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ')
                results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))]
                #print(prefix + "after flattening Prompt, results is", results)
            else:
                raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
            #print(prefix + "-> after flattening", type(node), "results is", results)
            return results

        #print("flattening", root)

        flattened_parts = []
        for part in root.prompts:
            flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
        weights = root.weights
        return Conjunction(flattened_parts, weights)

    def build_parser_logic(self):

        lparen = pp.Literal("(").suppress()
        rparen = pp.Literal(")").suppress()
        # accepts int or float notation, always maps to float
        number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))
        SPACE_CHARS = ' \t\n'

        prompt_part = pp.Forward()
        word = pp.Forward()

        def make_fragment(x):
            #print("### making fragment for", x)
            if type(x) is str:
                return Fragment(x)
            elif type(x) is pp.ParseResults or type(x) is list:
                return Fragment(' '.join([s for s in x]))
            else:
                raise PromptParser.ParsingException("Cannot make fragment from " + str(x))

        # attention control of the form +(phrase) / -(phrase) / <weight>(phrase)
        # phrase can be multiple words; use multiple +/- signs to increase the effect, or give a floating
        # point or integer weight
        attention = pp.Forward()
        attention_head = (number | pp.Word('+') | pp.Word('-'))\
            .set_name("attention_head")\
            .set_debug(False)
        fragment_inside_attention = pp.CharsNotIn(SPACE_CHARS+'()')\
            .set_parse_action(make_fragment)\
            .set_name("fragment_inside_attention")\
            .set_debug(False)
        attention_with_parens = pp.Forward()
        attention_with_parens_body = pp.nested_expr(content=pp.delimited_list((attention_with_parens | fragment_inside_attention), delim=SPACE_CHARS))
        attention_with_parens << (attention_head + attention_with_parens_body)

        def make_attention(x):
            # print("making Attention from parsing with args", x0, x1)
            weight = 1
            # number(str)
            if type(x[0]) is float or type(x[0]) is int:
                weight = float(x[0])
            # +(str) or -(str) or +str or -str
            elif type(x[0]) is str:
                base = self.attention_plus_base if x[0][0] == '+' else self.attention_minus_base
                weight = pow(base, len(x[0]))
            # print("Making attention with children of type", [str(type(x)) for x in x1])
            return Attention(weight=weight, children=x[1])

        attention_with_parens.set_parse_action(make_attention)\
            .set_name("attention_with_parens")\
            .set_debug(False)

        # attention control of the form ++word --word (no parens)
        attention_without_parens = (
                (pp.Word('+') | pp.Word('-')) +
                pp.CharsNotIn(SPACE_CHARS+'()').set_parse_action(lambda x: [[make_fragment(x)]])
            )\
            .set_name("attention_without_parens")\
            .set_debug(False)
        attention_without_parens.set_parse_action(make_attention)

        attention << (attention_with_parens | attention_without_parens)\
            .set_name("attention")\
            .set_debug(False)

        # fragments of text with no attention control
        word << pp.Word(pp.printables).set_parse_action(lambda x: Fragment(' '.join([s for s in x])))
        word.set_name("word")
        word.set_debug(False)
        prompt_part << (attention | word)
        prompt_part.set_debug(False)
        prompt_part.set_name("prompt_part")

        # root prompt definition
        prompt = pp.Group(pp.OneOrMore(prompt_part))\
            .set_parse_action(lambda x: Prompt(x[0]))

        # weighted blend of prompts
        # ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or
        # int weights.
        # can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c)

        def make_prompt_from_quoted_string(x):
            #print(' got quoted prompt', x)

            x_unquoted = x[0][1:-1]
            if len(x_unquoted.strip()) == 0:
                # print(' b : just an empty string')
                return Prompt([Fragment('')])
            # print(' b parsing ', c_unquoted)
            x_parsed = prompt.parse_string(x_unquoted)
            #print(" quoted prompt was parsed to", type(x_parsed), ":", x_parsed)
            return x_parsed[0]

        quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
        quoted_prompt.set_name('quoted_prompt')

        blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms')
        blend_weights = pp.delimited_list(number).set_name('blend_weights')
        blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
                         + pp.Literal(".blend").suppress()
                         + lparen + pp.Group(blend_weights) + rparen).set_name('blend')
        blend.set_debug(False)

        blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1]))

        conjunction_terms = blend_terms.copy().set_name('conjunction_terms')
        conjunction_weights = blend_weights.copy().set_name('conjunction_weights')
        conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen
                                                      + pp.Literal(".and").suppress()
                                                      + lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction')

        def make_conjunction(x):
            parts_raw = x[0][0]
            weights = x[0][1] if len(x[0]) > 1 else [1.0]*len(parts_raw)
            parts = [part for part in parts_raw]
            return Conjunction(parts, weights)

        conjunction_with_parens_and_quotes.set_parse_action(make_conjunction)

        implicit_conjunction = pp.OneOrMore(blend | prompt)
        implicit_conjunction.set_parse_action(lambda x: Conjunction(x))

        conjunction = conjunction_with_parens_and_quotes | implicit_conjunction
        conjunction.set_debug(False)

        # top-level is a conjunction of one or more blends or prompts
        return conjunction
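
A minimal usage sketch of the parser above (the expected output follows from the attention rules exercised in the unit tests at the end of this diff: each '+' multiplies the fragment weight by attention_plus_base, 1.1 by default):

    from ldm.invoke.prompt_parser import PromptParser

    pp = PromptParser()
    print(pp.parse('fire ++(flames)'))
    # Conjunction:[FlattenedPrompt:[('fire', 1.0), ('flames', 1.2100000000000002)]] | weights [1.0]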

@@ -34,23 +34,21 @@ class DDIMSampler(Sampler):
         b, *_, device = *x.shape, x.device

         if (
-            unconditional_conditioning is None
-            or unconditional_guidance_scale == 1.0
+            (unconditional_conditioning is None
+            or unconditional_guidance_scale == 1.0)
+            and type(c) is not list
         ):
             e_t = self.model.apply_model(x, t, c)
         else:
-            x_in = torch.cat([x] * 2)
-            t_in = torch.cat([t] * 2)
-            c_in = torch.cat([unconditional_conditioning, c])
-            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
-            e_t = e_t_uncond + unconditional_guidance_scale * (
-                e_t - e_t_uncond
-            )
+            e_t = self.apply_weighted_conditioning_list(x, t, self.model.apply_model, unconditional_conditioning, c, unconditional_guidance_scale)

         if score_corrector is not None:
             assert self.model.parameterization == 'eps'
+            if type(c) is list and len(c) > 1:
+                print("warning: ddim score modifier currently ignores all but the first part of the prompt conjunction, this is probably wrong")
+            corrector_c = [c[0][0] if type(c) is list else c]
             e_t = score_corrector.modify_score(
-                self.model, e_t, x, t, c, **corrector_kwargs
+                self.model, e_t, x, t, corrector_c, **corrector_kwargs
             )

         alphas = (
@@ -820,13 +820,13 @@ class LatentDiffusion(DDPM):
         )
         return self.scale_factor * z

-    def get_learned_conditioning(self, c):
+    def get_learned_conditioning(self, c, attention_weights=None):
         if self.cond_stage_forward is None:
             if hasattr(self.cond_stage_model, 'encode') and callable(
                 self.cond_stage_model.encode
             ):
                 c = self.cond_stage_model.encode(
-                    c, embedding_manager=self.embedding_manager
+                    c, embedding_manager=self.embedding_manager, attention_weights=attention_weights
                 )
                 if isinstance(c, DiagonalGaussianDistribution):
                     c = c.mode()

@@ -38,7 +38,8 @@ class CFGDenoiser(nn.Module):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
         cond_in = torch.cat([uncond, cond])
-        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+        unconditioned_x, conditioned_x = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+
         if self.warmup < self.warmup_max:
             thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
             self.warmup += 1
@@ -46,7 +47,28 @@ class CFGDenoiser(nn.Module):
             thresh = self.threshold
         if thresh > self.threshold:
             thresh = self.threshold
-        return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh)
+
+        # damian0815 thinking out loud notes:
+        # b + (a - b)*scale
+        # starting at the output that emerges from applying the negative prompt (by default ''),
+        # (-> this is why the unconditioning feels like a hammer)
+        # move toward the positive prompt by an amount controlled by cond_scale.
+        return cfg_apply_threshold(unconditioned_x + (conditioned_x - unconditioned_x) * cond_scale, thresh)
+
+
+class ProgrammableCFGDenoiser(CFGDenoiser):
+    def forward(self, x, sigma, uncond, cond, cond_scale):
+        forward_lambda = lambda x, t, c: self.inner_model(x, t, cond=c)
+        x_new = Sampler.apply_weighted_conditioning_list(x, sigma, forward_lambda, uncond, cond, cond_scale)
+
+        if self.warmup < self.warmup_max:
+            thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
+            self.warmup += 1
+        else:
+            thresh = self.threshold
+        if thresh > self.threshold:
+            thresh = self.threshold
+        return cfg_apply_threshold(x_new, threshold=thresh)

 class KSampler(Sampler):
@@ -181,7 +203,6 @@ class KSampler(Sampler):
             )

             # sigmas are set up in make_schedule - we take the last steps items
-            total_steps = len(self.sigmas)
             sigmas = self.sigmas[-S-1:]

             # x_T is variation noise. When an init image is provided (in x0) we need to add
@@ -194,7 +215,7 @@ class KSampler(Sampler):
             else:
                 x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]

-            model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
+            model_wrap_cfg = ProgrammableCFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
             extra_args = {
                 'cond': conditioning,
                 'uncond': unconditional_conditioning,
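
The guidance line in CFGDenoiser.forward is standard classifier-free guidance. A small numeric sketch of the "b + (a - b)*scale" note above (values invented for illustration):

    import torch

    b = torch.tensor([0.0, 0.0])   # prediction under the negative/empty prompt
    a = torch.tensor([1.0, 2.0])   # prediction under the positive prompt
    scale = 7.5
    out = b + (a - b) * scale      # scale == 1 reproduces a; larger values push further from b
    print(out)                     # tensor([ 7.5000, 15.0000])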

@@ -4,6 +4,8 @@ ldm.models.diffusion.sampler
 Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc
 '''

+from math import ceil
+
 import torch
 import numpy as np
 from tqdm import tqdm
@@ -411,3 +413,54 @@ class Sampler(object):
             return self.model.inner_model.q_sample(x0,ts)
         '''
         return self.model.q_sample(x0,ts)
+
+    @classmethod
+    def apply_weighted_conditioning_list(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
+        x_in = torch.cat([x] * 2)
+        t_in = torch.cat([t] * 2) # aka sigmas
+
+        deltas = None
+        uncond_latents = None
+        weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
+
+        # below is fugly omg
+        num_actual_conditionings = len(c_or_weighted_c_list)
+        conditionings = [uc] + [c for c, weight in weighted_cond_list]
+        weights = [1] + [weight for c, weight in weighted_cond_list]
+        chunk_count = ceil(len(conditionings)/2)
+        assert len(conditionings) >= 2, "need at least one uncond and one cond"
+        deltas = None
+        for chunk_index in range(chunk_count):
+            offset = chunk_index*2
+            chunk_size = min(2, len(conditionings)-offset)
+
+            if chunk_size == 1:
+                c_in = conditionings[offset]
+                latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
+                latents_b = None
+            else:
+                c_in = torch.cat(conditionings[offset:offset+2])
+                latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
+
+            # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioning
+            if chunk_index == 0:
+                uncond_latents = latents_a
+                deltas = latents_b - uncond_latents
+            else:
+                deltas = torch.cat((deltas, latents_a - uncond_latents))
+                if latents_b is not None:
+                    deltas = torch.cat((deltas, latents_b - uncond_latents))
+
+        # merge the weighted deltas together into a single merged delta
+        per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
+        normalize = False
+        if normalize:
+            per_delta_weights /= torch.sum(per_delta_weights)
+        reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
+        deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
+
+        # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
+        # assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
+        return uncond_latents + deltas_merged * global_guidance_scale
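
In effect apply_weighted_conditioning_list computes uncond + scale * sum_i(w_i * (cond_i - uncond)), batching the model calls in pairs. A standalone numeric sketch of the delta-merging step (tensor shapes simplified to (1, 4) for illustration; the values stand in for model outputs):

    import torch

    uncond = torch.zeros(1, 4)
    conds = [torch.ones(1, 4), 2 * torch.ones(1, 4)]
    weights = [0.5, 0.5]
    scale = 7.5

    deltas = torch.cat([c - uncond for c in conds])       # (2, 4)
    w = torch.tensor(weights).reshape(-1, 1)              # (2, 1)
    merged = torch.sum(deltas * w, dim=0, keepdim=True)   # (1, 4)
    out = uncond + merged * scale
    print(out)   # every entry is (0.5*1 + 0.5*2) * 7.5 == 11.25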

@@ -1,3 +1,5 @@
+import math
+
 import torch
 import torch.nn as nn
 from functools import partial
@@ -454,6 +456,207 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     def encode(self, text, **kwargs):
         return self(text, **kwargs)

+
+class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
+
+    attention_weights_key = "attention_weights"
+
+    def build_token_list_fragment(self, fragment: str, weight: float) -> (torch.Tensor, torch.Tensor):
+        batch_encoding = self.tokenizer(
+            fragment,
+            truncation=True,
+            max_length=self.max_length,
+            return_length=True,
+            return_overflowing_tokens=False,
+            padding='do_not_pad',
+            return_tensors='pt',
+        )
+        return batch_encoding, torch.ones_like(batch_encoding['input_ids']) * weight
+
+    def forward(self, text: list, **kwargs):
+        '''
+        :param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
+        weights shall be applied.
+        :param kwargs: If the keyword arg "attention_weights" is passed, it shall contain a batch of lists of weights
+        for the prompt fragments. In this case text must contain batches of lists of prompt fragments.
+        :return: A tensor of shape (B, 77, 768) containing weighted embeddings
+        '''
+        if self.attention_weights_key not in kwargs:
+            # fallback to base class implementation
+            return super().forward(text, **kwargs)
+
+        attention_weights = kwargs[self.attention_weights_key]
+        # self.transformer doesn't like receiving "attention_weights" as an argument
+        kwargs.pop(self.attention_weights_key)
+
+        batch_z = None
+        for fragments, weights in zip(text, attention_weights):
+
+            # First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
+            # applying a multiplier to the CFG scale on a per-token basis).
+            # For tokens weighted <1, intuitively we want SD to become not merely *less* interested in the concept
+            # captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active
+            # interest, however small, in redness; what the user probably intends when they attach the number 0.01 to
+            # "red" is to tell SD that it should almost completely *ignore* redness).
+            # To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt
+            # string from which the low-weighted fragment has simply been removed. The closer the weight is to zero, the
+            # closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment.
+
+            # handle weights >=1
+            tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights)
+            base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
+
+            # this is our starting point
+            embeddings = base_embedding.unsqueeze(0)
+            per_embedding_weights = [1.0]
+
+            # now handle weights <1
+            # Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped
+            # with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting
+            # embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words
+            # removed.
+            # eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding
+            # for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it
+            # such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
+            for index, fragment_weight in enumerate(weights):
+                if fragment_weight < 1:
+                    fragments_without_this = fragments[:index] + fragments[index+1:]
+                    weights_without_this = weights[:index] + weights[index+1:]
+                    tokens, per_token_weights = self.get_tokens_and_weights(fragments_without_this, weights_without_this)
+                    embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
+
+                    embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
+                    # weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
+                    # if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
+                    # therefore:
+                    # fragment_weight = 1: we are at base_z => lerp weight 0
+                    # fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1
+                    # fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf
+                    # so let's use tan(), because:
+                    # tan is 0.0 at 0,
+                    #        1.0 at PI/4, and
+                    #        inf at PI/2
+                    # -> tan((1-weight)*PI/2) should give us ideal lerp weights
+                    epsilon = 1e-9
+                    fragment_weight = max(epsilon, fragment_weight) # inf is bad
+                    embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
+                    # todo handle negative weight?
+
+                    per_embedding_weights.append(embedding_lerp_weight)
+
+            # lerped embeddings has shape (77, 768)
+            lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
+
+            print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
+
+            # append to batch
+            batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat((batch_z, lerped_embeddings.unsqueeze(0)), dim=0)
+
+        # should have shape (B, 77, 768)
+        print(f"assembled all tokens into tensor of shape {batch_z.shape}")
+
+        return batch_z
+
+    @classmethod
+    def apply_embedding_weights(cls, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize: bool) -> torch.Tensor:
+        per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
+        if normalize:
+            per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
+        reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
+        #reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
+        return torch.sum(embeddings * reshaped_weights, dim=1)
+
+    def get_tokens_and_weights(self, fragments: list[str], weights: list[float]) -> (torch.Tensor, torch.Tensor):
+        '''
+        :param fragments:
+        :param weights: Per-fragment weights (CFG scaling). No need for these to be normalized. They will not be normalized here and that's fine.
+        :return:
+        '''
+        # empty is meaningful
+        if len(fragments) == 0 and len(weights) == 0:
+            fragments = ['']
+            weights = [1]
+        item_encodings = self.tokenizer(
+            fragments,
+            truncation=True,
+            max_length=self.max_length,
+            return_overflowing_tokens=False,
+            padding='do_not_pad',
+            return_tensors=None, # just give me a list of ints
+        )['input_ids']
+        all_tokens = []
+        per_token_weights = []
+        print("all fragments:", fragments, weights)
+        for index, fragment in enumerate(item_encodings):
+            weight = weights[index]
+            print("processing fragment", fragment, weight)
+            fragment_tokens = item_encodings[index]
+            print("fragment", fragment, "processed to", fragment_tokens)
+            # trim bos and eos markers before appending
+            all_tokens.extend(fragment_tokens[1:-1])
+            per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
+
+        if len(all_tokens) > self.max_length - 2:
+            print("prompt is too long and has been truncated")
+            all_tokens = all_tokens[:self.max_length - 2]
+
+        # pad out to a 77-entry array: [bos_token, <prompt tokens>, eos_token, ..., eos_token]
+        # (77 = self.max_length)
+        pad_length = self.max_length - 1 - len(all_tokens)
+        all_tokens.insert(0, self.tokenizer.bos_token_id)
+        all_tokens.extend([self.tokenizer.eos_token_id] * pad_length)
+        per_token_weights.insert(0, 1)
+        per_token_weights.extend([1] * pad_length)
+
+        all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
+        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
+        print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
+        return all_tokens_tensor, per_token_weights_tensor
+
+    def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
+        '''
+        Build a tensor representing the passed-in tokens, each of which has a weight.
+        :param tokens: A tensor of shape (77) containing token ids (integers)
+        :param per_token_weights: A tensor of shape (77) containing weights (floats)
+        :param weight_delta_from_empty: Whether to multiply the whole feature vector for each token or just its distance from an "empty" feature vector
+        :param kwargs: passed on to self.transformer()
+        :return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
+        '''
+        #print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
+        z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs)
+        batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
+
+        if weight_delta_from_empty:
+            empty_tokens = self.tokenizer([''] * z.shape[0],
+                                          truncation=True,
+                                          max_length=self.max_length,
+                                          padding='max_length',
+                                          return_tensors='pt'
+                                          )['input_ids'].to(self.device)
+            empty_z = self.transformer(input_ids=empty_tokens, **kwargs)
+            z_delta_from_empty = z - empty_z
+            weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
+
+            weighted_z_delta_from_empty = (weighted_z - empty_z)
+            print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item())
+
+            #print("using empty-delta method, first 5 rows:")
+            #print(weighted_z[:5])
+
+            return weighted_z
+        else:
+            original_mean = z.mean()
+            z *= batch_weights_expanded
+            after_weighting_mean = z.mean()
+            # correct the mean. not sure if this is right but it's what the automatic1111 fork of SD does
+            mean_correction_factor = original_mean/after_weighting_mean
+            z *= mean_correction_factor
+            return z


 class FrozenCLIPTextEmbedder(nn.Module):
     """

tests/test_prompt_parser.py (new file, 136 lines)

@@ -0,0 +1,136 @@
import unittest

from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt


def parse_prompt(prompt_string):
    pp = PromptParser()
    #print(f"parsing '{prompt_string}'")
    parse_result = pp.parse(prompt_string)
    #print(f"-> parsed '{prompt_string}' to {parse_result}")
    return parse_result


class PromptParserTestCase(unittest.TestCase):

    def test_empty(self):
        self.assertEqual(Conjunction([FlattenedPrompt([('', 1)])]), parse_prompt(''))

    def test_basic(self):
        self.assertEqual(Conjunction([FlattenedPrompt([('fire (flames)', 1)])]), parse_prompt("fire (flames)"))
        self.assertEqual(Conjunction([FlattenedPrompt([("fire flames", 1)])]), parse_prompt("fire flames"))
        self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames", 1)])]), parse_prompt("fire, flames"))
        self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames , fire", 1)])]), parse_prompt("fire, flames , fire"))

    def test_attention(self):
        self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.5)])]), parse_prompt("0.5(flames)"))
        self.assertEqual(Conjunction([FlattenedPrompt([('fire flames', 0.5)])]), parse_prompt("0.5(fire flames)"))
        self.assertEqual(Conjunction([FlattenedPrompt([('flames', 1.1)])]), parse_prompt("+(flames)"))
        self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.9)])]), parse_prompt("-(flames)"))
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1), ('flames', 0.5)])]), parse_prompt("fire 0.5(flames)"))
        self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(1.1, 2))])]), parse_prompt("++(flames)"))
        self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(0.9, 2))])]), parse_prompt("--(flames)"))
        self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))])]), parse_prompt("---(flowers) +++flames"))
        self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames+', pow(1.1, 3))])]),
                         parse_prompt("---(flowers) +++flames+"))
        self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1)])]),
                         parse_prompt("+(pretty flowers)"))
        self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1), (', the flames are too hot', 1)])]),
                         parse_prompt("+(pretty flowers), the flames are too hot"))

    def test_no_parens_attention_runon(self):
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("++fire flames"))
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("--fire flames"))
        self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("flowers ++fire flames"))
        self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("flowers --fire flames"))

    def test_explicit_conjunction(self):
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and(1,1)'))
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()'))
        self.assertEqual(
            Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()'))
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("2.0(fire)", "-flames").and()'))
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]),
                                      FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()'))

    def test_conjunction_weights(self):
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[2.0,1.0]), parse_prompt('("fire", "flames").and(2,1)'))
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[1.0,2.0]), parse_prompt('("fire", "flames").and(1,2)'))

        # each bad parse needs its own assertRaises block, otherwise only the first is actually checked
        with self.assertRaises(PromptParser.ParsingException):
            parse_prompt('("fire", "flames").and(2)')
        with self.assertRaises(PromptParser.ParsingException):
            parse_prompt('("fire", "flames").and(2,1,2)')

    def test_complex_conjunction(self):
        self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]),
                         parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle)\").and(0.5, 0.5)"))

    def test_badly_formed(self):
        def make_untouched_prompt(prompt):
            return Conjunction([FlattenedPrompt([(prompt, 1.0)])])

        def assert_if_prompt_string_not_untouched(prompt):
            self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt))

        assert_if_prompt_string_not_untouched('a test prompt')
        assert_if_prompt_string_not_untouched('a badly (formed test prompt')
        assert_if_prompt_string_not_untouched('a badly formed test+ prompt')
        assert_if_prompt_string_not_untouched('a badly (formed test+ prompt')
        assert_if_prompt_string_not_untouched('a badly (formed test+ )prompt')
        assert_if_prompt_string_not_untouched('(((a badly (formed test+ )prompt')
        assert_if_prompt_string_not_untouched('(a (ba)dly (f)ormed test+ prompt')
        self.assertEqual(Conjunction([FlattenedPrompt([('(a (ba)dly (f)ormed test+', 1.0), ('prompt', 1.1)])]),
                         parse_prompt('(a (ba)dly (f)ormed test+ +prompt'))
        self.assertEqual(Conjunction([Blend([FlattenedPrompt([('((a badly (formed test+', 1.0)])], weights=[1.0])]),
                         parse_prompt('("((a badly (formed test+ ").blend(1.0)'))

    def test_blend(self):
        self.assertEqual(Conjunction(
            [Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]),
            parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)")
        )
        self.assertEqual(Conjunction([Blend(
            [FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('hi', 1.0)])],
            [0.7, 0.3, 1.0])]),
            parse_prompt("(\"fire\", \"fire flames\", \"hi\").blend(0.7, 0.3, 1.0)")
        )
        self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]),
                                             FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]),
                                             FlattenedPrompt([('hi', 1.0)])],
                                            weights=[0.7, 0.3, 1.0])]),
                         parse_prompt("(\"fire\", \"fire flames ++(hot)\", \"hi\").blend(0.7, 0.3, 1.0)")
        )

        # blending a single entry is not a failure
        self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]),
                         parse_prompt("(\"fire\").blend(0.7)")
        )
        # blend with empty
        self.assertEqual(
            Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
            parse_prompt("(\"fire\", \"\").blend(0.7, 1)")
        )
        self.assertEqual(
            Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
            parse_prompt("(\"fire\", \" \").blend(0.7, 1)")
        )
        self.assertEqual(
            Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([(',', 1.0)])], [0.7, 1.0])]),
            parse_prompt("(\"fire\", \" , \").blend(0.7, 1)")
        )

    def test_nested(self):
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)])]),
                         parse_prompt('fire 2.0(flames 1.5(trees))'))
        self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]),
                                                     FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])],
                                            weights=[1.0, 1.0])]),
                         parse_prompt('("fire ++(flames)", "mountain 2(man)").blend(1,1)'))


if __name__ == '__main__':
    unittest.main()
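
Assuming the repository root is on the Python path, the new test suite can be run with the standard unittest runner:

    python -m unittest tests.test_prompt_parser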