Mirror of https://github.com/invoke-ai/InvokeAI
Squashed commit of the following (all commits authored by Damian at mba <damian@frey.NOSPAMco.nz> unless noted otherwise):

commit 9bb0b5d0036c4dffbb72ce11e097fae4ab63defd (Sat Oct 15 23:43:41 2022 +0200)
    undo local_files_only stuff

commit eed93f5d30c34cfccaf7497618ae9af17a5ecfbb (Sat Oct 15 23:40:37 2022 +0200)
    Revert "Merge branch 'development-invoke' into fix-prompts"
    This reverts commit 7c40892a9f184f7e216f14d14feb0411c5a90e24, reversing changes made to e3f2dd62b0548ca6988818ef058093a4f5b022f2.

commit f06d6024e345c69e6d5a91ab5423925a68ee95a7 (Thu Oct 13 23:30:16 2022 +0200)
    more efficiently handle multiple conditioning

commit 5efdfcbcd980ce6202ab74e7f90e7415ce7260da (Merge: b9c0dc5 ac08bb6; Thu Oct 13 14:51:01 2022 +0200)
    Merge branch 'optional-disable-karras-schedule' into fix-prompts

commit ac08bb6fd25e19a9d35cf6c199e66500fb604af1 (Thu Oct 13 14:50:43 2022 +0200)
    append '*use_model_sigmas*' to prompt string to use model sigmas

commit 70d8c05a3ff329409f76204f4af94e55d468ab8b (Thu Oct 13 12:12:17 2022 +0200)
    make karras scheduling switchable
    commit d60df54f69 replaced the model's own scheduling with karras scheduling. this has changed image generation (seems worse now?) this commit wraps the change in a bool.

commit b9c0dc5f1a658a0e6c3936000e9ae559e1c7a1db (Wed Oct 12 20:16:00 2022 +0200)
    add test of more complex conjunction

commit 9ac0c15cc0d7b5f6df3289d3ad474260972a17be (Wed Oct 12 17:18:25 2022 +0200)
    improve comments

commit ad33bce60590b87b2a93e90f16dc9d3e935d04a5 (Wed Oct 12 17:04:46 2022 +0200)
    put back thresholding stuff

commit 4852c698a325049834ba0d4b358f07210bc7171a (Wed Oct 12 14:25:02 2022 +0200)
    notes on improving conjunction efficiency

commit a53bb1e5b68025d09642b935ae6a9a015cfaf2d6 (Wed Oct 12 14:14:33 2022 +0200)
    optional weights support for Conjunction

commit fec79ab15e4f0c84dd61cb1b45a5e6a72ae4aaeb (Wed Oct 12 12:07:27 2022 +0200)
    fix blend error and log parsing output

commit 1f751c2a039f9c97af57b18e0f019512631d5a25 (Wed Oct 12 10:33:33 2022 +0200)
    fix broken euler sampler

commit 02f8148d17efe4b6bde8d29b827092a0626363ee (Wed Oct 12 10:24:20 2022 +0200)
    cleanup prompt parser

commit 8028d49ae6c16c0d6ec9c9de9c12d56c32201421 (Wed Oct 12 10:14:18 2022 +0200)
    explicit conjunction, improve flattening logic

commit 8a1710892185f07eb77483f7edae0fc4d6bbb250 (Tue Oct 11 22:59:30 2022 +0200)
    adapt multi-conditioning to also work with ddim

commit 53802a839850d0d1ff017c6bafe457c4bed750b0 (Tue Oct 11 22:31:42 2022 +0200)
    unconditioning is also fancy-prompt-syntaxable

commit 7c40892a9f184f7e216f14d14feb0411c5a90e24 (Merge: e3f2dd6 dbe0da4; Tue Oct 11 21:39:54 2022 +0200)
    Merge branch 'development-invoke' into fix-prompts

commit e3f2dd62b0548ca6988818ef058093a4f5b022f2 (Merge: eef0e48 06f542e; Tue Oct 11 21:38:09 2022 +0200)
    Merge remote-tracking branch 'upstream/development' into fix-prompts

commit eef0e484c2eaa1bd4e0e0b1d3f8d7bba38478144 (Tue Oct 11 21:26:25 2022 +0200)
    fix run-on paren-less attention, add some comments

commit fd29afdf0e9f5e0cdc60239e22480c36ca0aaeca (Tue Oct 11 21:03:02 2022 +0200)
    python 3.9 compatibility

commit 26f7646eef7f39bc8f7ce805e747df0f723464da (Tue Oct 11 20:58:42 2022 +0200)
    first pass connecting PromptParser to conditioning

commit ae53dff3796d7b9a5e7ed30fa1edb0374af6cd8d (Tue Oct 11 20:51:15 2022 +0200)
    update frontend dist

commit 9be4a59a2d76f49e635474b5984bfca826a5dab4 (Tue Oct 11 19:01:39 2022 +0200)
    fix issues with correctness checking FlattenedPrompt

commit 3be212323eab68e72a363a654124edd9809e4cf0 (Tue Oct 11 18:43:16 2022 +0200)
    parsing nested seems to work pretty ok

commit acd73eb08cf67c27cac8a22934754321256f56a9 (Tue Oct 11 18:26:17 2022 +0200)
    wip introducing FlattenedPrompt class

commit 71698d5c7c2ac855b690d8ef67e8830148c59eda (Tue Oct 11 15:59:42 2022 +0200)
    recursive attention weighting seems to actually work

commit a4e1ec6b20deb7cc0cd12737bdbd266e56144709 (Tue Oct 11 15:06:24 2022 +0200)
    now apparently almost supported nested attention

commit da76fd1ddf22a3888cdc08fd4fed38d8b178e524 (Tue Oct 11 13:23:37 2022 +0200)
    wip prompt parsing

commit dbe0da4572c2ac22f26a7afd722349a5680a9e47 (Author: Kyle Schouviller <kyle0654@hotmail.com>; Mon Oct 10 22:32:35 2022 -0700)
    Adding node-based invocation apps

commit 8f2a2ffc083366de74d7dae471b50b6f98a7c5f8 (Mon Oct 10 19:03:18 2022 +0200)
    fix merge issues

commit 73118dee2a8f4891700756e014caf1c9ca629267 (Merge: fd00844 12413b0; Mon Oct 10 12:42:48 2022 +0200)
    Merge remote-tracking branch 'upstream/development' into fix-prompts

commit fd0084413541013c2cf71e006af0392719bef53d (Mon Oct 10 12:39:38 2022 +0200)
    wip prompt parsing

commit 0be9363db9307859d2b65cffc6af01f57d7873a4 (Mon Oct 10 03:20:06 2022 +0200)
    better +/- attention parsing

commit 5383f691874a58ab01cda1e4fac6cf330146526a (Mon Oct 10 02:27:47 2022 +0200)
    prompt parser seems to work

commit 591d098a33ce35462428d8c169501d8ed73615ab (Sun Oct 9 20:25:37 2022 +0200)
    supports weighting unconditioning, cross-attention with |

commit 7a7220563aa05a2980235b5b908362f66b728309 (Sun Oct 9 18:15:56 2022 +0200)
    i think cross attention might be working?

commit 951ed391e7126bff228c18b2db304ad28d59644a (Sun Oct 9 16:04:54 2022 +0200)
    weighted CFG denoiser working with a single item

commit ee532a0c2827368c9e45a6a5f3975666402873da (Sun Oct 9 06:33:40 2022 +0200)
    wip probably doesn't work or compile

commit 14654bcbd207b9ca28a6cbd37dbd967d699b062d (Fri Oct 7 18:11:48 2022 +0200)
    use tan() to calculate embedding weight for <1 attentions

commit 1a8e76b31aa5abf5150419ebf3b29d4658d07f2b (Fri Oct 7 16:14:54 2022 +0200)
    fix bad math.max reference

commit f697ff896875876ccaa1e5527405bdaa7ed27cde (Fri Oct 7 15:55:57 2022 +0200)
    respect http[s]x protocol when making socket.io middleware

commit 41d3dd4eeae8d4efb05dfb44fc6d8aac5dc468ab (Fri Oct 7 13:29:54 2022 +0200)
    fractional weighting works, by blending with prompts excluding the word

commit 087fb6dfb3e8f5e84de8c911f75faa3e3fa3553c (Fri Oct 7 10:52:03 2022 +0200)
    wip doing weights <1 by averaging with conditioning absent the lower-weighted fragment

commit 3c49e3f3ec7c18dc60f3e18ed2f7f0d97aad3a47 (Fri Oct 7 10:36:15 2022 +0200)
    notate CFGDenoiser, perhaps

commit d2bcf1bb522026ebf209ad0103f6b370383e5070 (Thu Oct 6 05:04:47 2022 +0200)
    hack blending syntax to test attention weighting more extensively

commit 94904ef2cf917f74ec23ef7a570e12ff8255b048 (Thu Oct 6 04:56:37 2022 +0200)
    conditioning works, apparently

commit 7c6663ddd70f665fd1308b6dd74f92ca393a8df5 (Thu Oct 6 02:20:24 2022 +0200)
    attention weighting, definitely works in positive direction

commit 5856d453a9b020bc1a28ff643ae1f58c12c9be73 (Tue Oct 4 19:02:14 2022 +0200)
    wip bubbling weights down

commit a2ed14fd9b7d3cb36b6c5348018b364c76d1e892 (Tue Oct 4 17:35:39 2022 +0200)
    bring in changes from PC
parent 92d4dfaabf
commit c3b992db96
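For orientation before the diffs: the syntax this branch introduces is attention weighting (+/-/numeric), weighted blends, and explicit conjunctions. A minimal sketch of driving the new parser directly, assuming the repository root is importable (the exact behaviour is pinned down by tests/test_prompt_parser.py at the bottom of this diff):

    from ldm.invoke.prompt_parser import PromptParser

    pp = PromptParser()
    pp.parse("fire +(flames)")     # each '+' multiplies attention by attention_plus_base (default 1.1)
    pp.parse("fire 0.5(flames)")   # explicit numeric attention weight
    pp.parse('("fire", "fire flames").blend(0.7, 0.3)')  # weighted blend of two prompts
    pp.parse('("fire", "flames").and(1, 1)')             # conjunction with per-prompt weights

Each call returns a Conjunction whose parts are FlattenedPrompts or Blends.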
@@ -527,7 +527,7 @@ def parameters_to_generated_image_metadata(parameters):
     rfc_dict["sampler"] = parameters["sampler_name"]
 
     # display weighted subprompts (liable to change)
-    subprompts = split_weighted_subprompts(parameters["prompt"])
+    subprompts = split_weighted_subprompts(parameters["prompt"], skip_normalize=True)
     subprompts = [{"prompt": x[0], "weight": x[1]} for x in subprompts]
     rfc_dict["prompt"] = subprompts
 
@@ -76,4 +76,4 @@ model:
       target: torch.nn.Identity
 
     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+      target: ldm.modules.encoders.modules.WeightedFrozenCLIPEmbedder
@@ -438,7 +438,7 @@ class Generate:
                sampler=self.sampler,
                steps=steps,
                cfg_scale=cfg_scale,
-               conditioning=(uc, c),
+               conditioning=(uc, c),  # here change to arrays
                ddim_eta=ddim_eta,
                image_callback=image_callback,  # called after the final image is generated
                step_callback=step_callback,    # called after each intermediate image is generated
@@ -477,6 +477,10 @@ class Generate:
                print('**Interrupted** Partial results will be returned.')
            else:
                raise KeyboardInterrupt
+        # brute-force fallback
+        except Exception as e:
+            print(traceback.format_exc(), file=sys.stderr)
+            print('>> Could not generate image.')
 
        toc = time.time()
        print('>> Usage stats:')
@@ -12,42 +12,76 @@ log_tokenization() print out colour-coded tokens and warn if truncated
 import re
 import torch
 
-def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
+from .prompt_parser import PromptParser, Fragment, Attention, Blend, Conjunction, FlattenedPrompt
+from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
+
+
+def get_uc_and_c(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):
 
     # Extract Unconditioned Words From Prompt
     unconditioned_words = ''
     unconditional_regex = r'\[(.*?)\]'
-    unconditionals = re.findall(unconditional_regex, prompt)
+    unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)
 
     if len(unconditionals) > 0:
         unconditioned_words = ' '.join(unconditionals)
 
         # Remove Unconditioned Words From Prompt
         unconditional_regex_compile = re.compile(unconditional_regex)
-        clean_prompt = unconditional_regex_compile.sub(' ', prompt)
-        prompt = re.sub(' +', ' ', clean_prompt)
+        clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned)
+        prompt_string_cleaned = re.sub(' +', ' ', clean_prompt)
+    else:
+        prompt_string_cleaned = prompt_string_uncleaned
 
-    uc = model.get_learned_conditioning([unconditioned_words])
+    pp = PromptParser()
 
-    # get weighted sub-prompts
-    weighted_subprompts = split_weighted_subprompts(
-        prompt, skip_normalize
-    )
+    def build_conditioning_list(prompt_string:str):
+        parsed_conjunction: Conjunction = pp.parse(prompt_string)
+        print(f"parsed '{prompt_string}' to {parsed_conjunction}")
+        assert (type(parsed_conjunction) is Conjunction)
+
+        conditioning_list = []
+        def make_embeddings_for_flattened_prompt(flattened_prompt: FlattenedPrompt):
+            if type(flattened_prompt) is not FlattenedPrompt:
+                raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead"
+            fragments = [x[0] for x in flattened_prompt.children]
+            attention_weights = [x[1] for x in flattened_prompt.children]
+            print(fragments, attention_weights)
+            return model.get_learned_conditioning([fragments], attention_weights=[attention_weights])
+
+        for part,weight in zip(parsed_conjunction.prompts, parsed_conjunction.weights):
+            if type(part) is Blend:
+                blend:Blend = part
+                embeddings_to_blend = None
+                for flattened_prompt in blend.prompts:
+                    this_embedding = make_embeddings_for_flattened_prompt(flattened_prompt)
+                    embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat((embeddings_to_blend, this_embedding))
+                blended_embeddings = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0), blend.weights, normalize=blend.normalize_weights)
+                conditioning_list.append((blended_embeddings, weight))
+            else:
+                flattened_prompt: FlattenedPrompt = part
+                embeddings = make_embeddings_for_flattened_prompt(flattened_prompt)
+                conditioning_list.append((embeddings, weight))
+
+        return conditioning_list
+
+    positive_conditioning_list = build_conditioning_list(prompt_string_cleaned)
+    negative_conditioning_list = build_conditioning_list(unconditioned_words)
+
+    if len(negative_conditioning_list) == 0:
+        negative_conditioning = model.get_learned_conditioning([['']], attention_weights=[[1]])
+    else:
+        if len(negative_conditioning_list)>1:
+            print("cannot do conjunctions on unconditioning for now")
+        negative_conditioning = negative_conditioning_list[0][0]
+
+    #positive_conditioning_list.append((get_blend_prompts_and_weights(prompt), this_weight))
+    #print("got empty_conditionining with shape", empty_conditioning.shape, "c[0][0] with shape", positive_conditioning[0][0].shape)
+
+    # "unconditioned" means "the conditioning tensor is empty"
+    uc = negative_conditioning
+    c = positive_conditioning_list
 
-    if len(weighted_subprompts) > 1:
-        # i dont know if this is correct.. but it works
-        c = torch.zeros_like(uc)
-        # normalize each "sub prompt" and add it
-        for subprompt, weight in weighted_subprompts:
-            log_tokenization(subprompt, model, log_tokens, weight)
-            c = torch.add(
-                c,
-                model.get_learned_conditioning([subprompt]),
-                alpha=weight,
-            )
-    else: # just standard 1 prompt
-        log_tokenization(prompt, model, log_tokens, 1)
-        c = model.get_learned_conditioning([prompt])
-        uc = model.get_learned_conditioning([unconditioned_words])
     return (uc, c)
 
 def split_weighted_subprompts(text, skip_normalize=False)->list:
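The rework above changes the contract of get_uc_and_c(): c is no longer a single tensor but a list of (embedding, weight) tuples, one per conjunction part, which is what the '# here change to arrays' comment in the Generate hunk refers to. A rough sketch of consuming the result; `model` and the tuple names are illustrative placeholders, and the (1, 77, 768) shape is taken from the WeightedFrozenCLIPEmbedder docstring further down:

    uc, c = get_uc_and_c('("fire", "smoke").and(2, 1)', model)
    # uc: one conditioning tensor built from [bracketed] negative words (default '')
    # c:  [(fire_embedding, 2.0), (smoke_embedding, 1.0)]
    for embedding, weight in c:
        print(embedding.shape, weight)   # torch.Size([1, 77, 768]) with weights 2.0 / 1.0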
ldm/invoke/prompt_parser.py (new file, 331 lines)
@@ -0,0 +1,331 @@
import pyparsing
import pyparsing as pp
from pyparsing import original_text_for


class Prompt():

    def __init__(self, parts: list):
        for c in parts:
            allowed_types = [Fragment, Attention, CFGScale]
            if type(c) not in allowed_types:
                raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {allowed_types} are allowed")
        self.children = parts
    def __repr__(self):
        return f"Prompt:{self.children}"
    def __eq__(self, other):
        return type(other) is Prompt and other.children == self.children

class FlattenedPrompt():
    def __init__(self, parts: list):
        # verify type correctness
        for c in parts:
            if type(c) is not tuple:
                raise PromptParser.ParsingException(
                    f"FlattenedPrompt cannot contain {type(c)}, only ('text', weight) tuples are allowed")
            text = c[0]
            weight = c[1]
            if type(text) is not str:
                raise PromptParser.ParsingException(f"FlattenedPrompt cannot contain {type(c)}, only ('text', weight) tuples are allowed")
            if type(weight) is not float and type(weight) is not int:
                raise PromptParser.ParsingException(
                    f"FlattenedPrompt cannot contain {type(c)}, only ('text', weight) tuples are allowed")
        # all looks good
        self.children = parts

    def __repr__(self):
        return f"FlattenedPrompt:{self.children}"
    def __eq__(self, other):
        return type(other) is FlattenedPrompt and other.children == self.children


class Attention():

    def __init__(self, weight: float, children: list):
        self.weight = weight
        self.children = children
        #print(f"A: requested attention '{children}' to {weight}")

    def __repr__(self):
        return f"Attention:'{self.children}' @ {self.weight}"
    def __eq__(self, other):
        return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment


class CFGScale():
    def __init__(self, scale_factor: float, fragment: str):
        self.fragment = fragment
        self.scale_factor = scale_factor
        #print(f"S: requested CFGScale '{fragment}' x {scale_factor}")

    def __repr__(self):
        return f"CFGScale:'{self.fragment}' x {self.scale_factor}"
    def __eq__(self, other):
        return type(other) is CFGScale and other.scale_factor == self.scale_factor and other.fragment == self.fragment



class Fragment():
    def __init__(self, text: str):
        assert(type(text) is str)
        self.text = text

    def __repr__(self):
        return "Fragment:'"+self.text+"'"
    def __eq__(self, other):
        return type(other) is Fragment and other.text == self.text

class Conjunction():
    def __init__(self, prompts: list, weights: list = None):
        # force everything to be a Prompt
        #print("making conjunction with", parts)
        self.prompts = [x if (type(x) is Prompt or type(x) is Blend or type(x) is FlattenedPrompt)
                        else Prompt(x) for x in prompts]
        self.weights = [1.0]*len(self.prompts) if weights is None else list(weights)
        if len(self.weights) != len(self.prompts):
            raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
        self.type = 'AND'

    def __repr__(self):
        return f"Conjunction:{self.prompts} | weights {self.weights}"
    def __eq__(self, other):
        return type(other) is Conjunction \
               and other.prompts == self.prompts \
               and other.weights == self.weights


class Blend():
    def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
        #print("making Blend with prompts", prompts, "and weights", weights)
        if len(prompts) != len(weights):
            raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
        for c in prompts:
            if type(c) is not Prompt and type(c) is not FlattenedPrompt:
                raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
        # upcast all lists to Prompt objects
        self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
                        else Prompt(x) for x in prompts]
        self.prompts = prompts
        self.weights = weights
        self.normalize_weights = normalize_weights

    def __repr__(self):
        return f"Blend:{self.prompts} | weights {self.weights}"
    def __eq__(self, other):
        return other.__repr__() == self.__repr__()


class PromptParser():

    class ParsingException(Exception):
        pass

    def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):

        self.attention_plus_base = attention_plus_base
        self.attention_minus_base = attention_minus_base

        self.root = self.build_parser_logic()


    def parse(self, prompt: str) -> [list]:
        '''
        :param prompt: The prompt string to parse
        :return: a tuple
        '''
        #print(f"!!parsing '{prompt}'")

        if len(prompt.strip()) == 0:
            return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0])

        root = self.root.parse_string(prompt)
        #print(f"'{prompt}' parsed to root", root)
        #fused = fuse_fragments(parts)
        #print("fused to", fused)

        return self.flatten(root[0])

    def flatten(self, root: Conjunction):

        def fuse_fragments(items):
            # print("fusing fragments in ", items)
            result = []
            for x in items:
                last_weight = result[-1][1] if len(result) > 0 else None
                this_text = x[0]
                this_weight = x[1]
                if last_weight is not None and last_weight == this_weight:
                    last_text = result[-1][0]
                    result[-1] = (last_text + ' ' + this_text, last_weight)
                else:
                    result.append(x)
            return result

        def flatten_internal(node, weight_scale, results, prefix):
            #print(prefix + "flattening", node, "...")
            if type(node) is pp.ParseResults:
                for x in node:
                    results = flatten_internal(x, weight_scale, results, prefix+'pr')
                #print(prefix, " ParseResults expanded, results is now", results)
            elif type(node) is Fragment:
                results.append((node.text, float(weight_scale)))
            elif type(node) is Attention:
                #if node.weight < 1:
                # todo: inject a blend when flattening attention with weight <1"
                for c in node.children:
                    results = flatten_internal(c, weight_scale*node.weight, results, prefix+' ')
            elif type(node) is Blend:
                flattened_subprompts = []
                #print(" flattening blend with prompts", node.prompts, "weights", node.weights)
                for prompt in node.prompts:
                    # prompt is a list
                    flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ')
                results += [Blend(prompts=flattened_subprompts, weights=node.weights)]
            elif type(node) is Prompt:
                #print(prefix + "about to flatten Prompt with children", node.children)
                flattened_prompt = []
                for child in node.children:
                    flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ')
                results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))]
                #print(prefix + "after flattening Prompt, results is", results)
            else:
                raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
            #print(prefix + "-> after flattening", type(node), "results is", results)
            return results

        #print("flattening", root)

        flattened_parts = []
        for part in root.prompts:
            flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
        weights = root.weights
        return Conjunction(flattened_parts, weights)



    def build_parser_logic(self):

        lparen = pp.Literal("(").suppress()
        rparen = pp.Literal(")").suppress()
        # accepts int or float notation, always maps to float
        number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))
        SPACE_CHARS = ' \t\n'

        prompt_part = pp.Forward()
        word = pp.Forward()

        def make_fragment(x):
            #print("### making fragment for", x)
            if type(x) is str:
                return Fragment(x)
            elif type(x) is pp.ParseResults or type(x) is list:
                return Fragment(' '.join([s for s in x]))
            else:
                raise PromptParser.ParsingException("Cannot make fragment from " + str(x))

        # attention control of the form +(phrase) / -(phrase) / <weight>(phrase)
        # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight
        attention = pp.Forward()
        attention_head = (number | pp.Word('+') | pp.Word('-'))\
            .set_name("attention_head")\
            .set_debug(False)
        fragment_inside_attention = pp.CharsNotIn(SPACE_CHARS+'()')\
            .set_parse_action(make_fragment)\
            .set_name("fragment_inside_attention")\
            .set_debug(False)
        attention_with_parens = pp.Forward()
        attention_with_parens_body = pp.nested_expr(content=pp.delimited_list((attention_with_parens | fragment_inside_attention), delim=SPACE_CHARS))
        attention_with_parens << (attention_head + attention_with_parens_body)

        def make_attention(x):
            # print("making Attention from parsing with args", x0, x1)
            weight = 1
            # number(str)
            if type(x[0]) is float or type(x[0]) is int:
                weight = float(x[0])
            # +(str) or -(str) or +str or -str
            elif type(x[0]) is str:
                base = self.attention_plus_base if x[0][0] == '+' else self.attention_minus_base
                weight = pow(base, len(x[0]))
            # print("Making attention with children of type", [str(type(x)) for x in x1])
            return Attention(weight=weight, children=x[1])

        attention_with_parens.set_parse_action(make_attention)\
            .set_name("attention_with_parens")\
            .set_debug(False)

        # attention control of the form ++word --word (no parens)
        attention_without_parens = (
            (pp.Word('+') | pp.Word('-')) +
            pp.CharsNotIn(SPACE_CHARS+'()').set_parse_action(lambda x: [[make_fragment(x)]])
        )\
            .set_name("attention_without_parens")\
            .set_debug(False)
        attention_without_parens.set_parse_action(make_attention)

        attention << (attention_with_parens | attention_without_parens)\
            .set_name("attention")\
            .set_debug(False)

        # fragments of text with no attention control
        word << pp.Word(pp.printables).set_parse_action(lambda x: Fragment(' '.join([s for s in x])))
        word.set_name("word")
        word.set_debug(False)
        prompt_part << (attention | word)
        prompt_part.set_debug(False)
        prompt_part.set_name("prompt_part")

        # root prompt definition
        prompt = pp.Group(pp.OneOrMore(prompt_part))\
            .set_parse_action(lambda x: Prompt(x[0]))

        # weighted blend of prompts
        # ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or
        # int weights.
        # can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c)

        def make_prompt_from_quoted_string(x):
            #print(' got quoted prompt', x)

            x_unquoted = x[0][1:-1]
            if len(x_unquoted.strip()) == 0:
                # print(' b : just an empty string')
                return Prompt([Fragment('')])
            # print(' b parsing ', c_unquoted)
            x_parsed = prompt.parse_string(x_unquoted)
            #print("  quoted prompt was parsed to", type(x_parsed),":", x_parsed)
            return x_parsed[0]

        quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
        quoted_prompt.set_name('quoted_prompt')

        blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms')
        blend_weights = pp.delimited_list(number).set_name('blend_weights')
        blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
                         + pp.Literal(".blend").suppress()
                         + lparen + pp.Group(blend_weights) + rparen).set_name('blend')
        blend.set_debug(False)


        blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1]))

        conjunction_terms = blend_terms.copy().set_name('conjunction_terms')
        conjunction_weights = blend_weights.copy().set_name('conjunction_weights')
        conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen
                                                      + pp.Literal(".and").suppress()
                                                      + lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction')
        def make_conjunction(x):
            parts_raw = x[0][0]
            weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw)
            parts = [part for part in parts_raw]
            return Conjunction(parts, weights)
        conjunction_with_parens_and_quotes.set_parse_action(make_conjunction)

        implicit_conjunction = pp.OneOrMore(blend | prompt)
        implicit_conjunction.set_parse_action(lambda x: Conjunction(x))

        conjunction = conjunction_with_parens_and_quotes | implicit_conjunction
        conjunction.set_debug(False)

        # top-level is a conjunction of one or more blends or prompts
        return conjunction
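PromptParser.parse() always runs the raw parse tree through flatten(), which multiplies Attention weights down through nesting and fuses adjacent fragments of equal weight. A worked example, with the expected structure taken from test_nested in the test file below:

    pp = PromptParser()
    pp.parse('fire 2.0(flames 1.5(trees))')
    # -> Conjunction([FlattenedPrompt([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)])])
    # 'trees' lands at 3.0 because weight_scale multiplies through nesting: 2.0 * 1.5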
@@ -34,23 +34,21 @@ class DDIMSampler(Sampler):
         b, *_, device = *x.shape, x.device
 
         if (
-            unconditional_conditioning is None
-            or unconditional_guidance_scale == 1.0
+            (unconditional_conditioning is None
+            or unconditional_guidance_scale == 1.0)
+            and c is not list
         ):
             e_t = self.model.apply_model(x, t, c)
         else:
-            x_in = torch.cat([x] * 2)
-            t_in = torch.cat([t] * 2)
-            c_in = torch.cat([unconditional_conditioning, c])
-            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
-            e_t = e_t_uncond + unconditional_guidance_scale * (
-                e_t - e_t_uncond
-            )
+            e_t = self.apply_weighted_conditioning_list(x, t, self.model.apply_model, unconditional_conditioning, c, unconditional_guidance_scale)
 
         if score_corrector is not None:
             assert self.model.parameterization == 'eps'
+            if c is list and len(c)>1:
+                print("warning: ddim score modifier currently ignores all but the first part of the prompt conjunction, this is probably wrong")
+            corrector_c = [c[0][0] if c is list else c]
             e_t = score_corrector.modify_score(
-                self.model, e_t, x, t, c, **corrector_kwargs
+                self.model, e_t, x, t, corrector_c, **corrector_kwargs
             )
 
         alphas = (
@@ -820,13 +820,13 @@ class LatentDiffusion(DDPM):
             )
             return self.scale_factor * z
 
-    def get_learned_conditioning(self, c):
+    def get_learned_conditioning(self, c, attention_weights=None):
         if self.cond_stage_forward is None:
             if hasattr(self.cond_stage_model, 'encode') and callable(
                 self.cond_stage_model.encode
             ):
                 c = self.cond_stage_model.encode(
-                    c, embedding_manager=self.embedding_manager
+                    c, embedding_manager=self.embedding_manager, attention_weights=attention_weights
                 )
                 if isinstance(c, DiagonalGaussianDistribution):
                     c = c.mode()
@@ -38,7 +38,8 @@ class CFGDenoiser(nn.Module):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
         cond_in = torch.cat([uncond, cond])
-        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+        unconditioned_x, conditioned_x = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+
         if self.warmup < self.warmup_max:
             thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
             self.warmup += 1
@@ -46,7 +47,28 @@ class CFGDenoiser(nn.Module):
             thresh = self.threshold
         if thresh > self.threshold:
             thresh = self.threshold
-        return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh)
+
+        # damian0815 thinking out loud notes:
+        # b + (a - b)*scale
+        # starting at the output that emerges applying the negative prompt (by default ''),
+        # (-> this is why the unconditioning feels like hammer)
+        # move toward the positive prompt by an amount controlled by cond_scale.
+        return cfg_apply_threshold(unconditioned_x + (conditioned_x - unconditioned_x) * cond_scale, thresh)
+
+
+class ProgrammableCFGDenoiser(CFGDenoiser):
+    def forward(self, x, sigma, uncond, cond, cond_scale):
+        forward_lambda = lambda x, t, c: self.inner_model(x, t, cond=c)
+        x_new = Sampler.apply_weighted_conditioning_list(x, sigma, forward_lambda, uncond, cond, cond_scale)
+
+        if self.warmup < self.warmup_max:
+            thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
+            self.warmup += 1
+        else:
+            thresh = self.threshold
+        if thresh > self.threshold:
+            thresh = self.threshold
+        return cfg_apply_threshold(x_new, threshold=thresh)
+
 
 class KSampler(Sampler):
@@ -181,7 +203,6 @@ class KSampler(Sampler):
             )
 
             # sigmas are set up in make_schedule - we take the last steps items
-            total_steps = len(self.sigmas)
             sigmas = self.sigmas[-S-1:]
 
         # x_T is variation noise. When an init image is provided (in x0) we need to add
@@ -194,7 +215,7 @@ class KSampler(Sampler):
         else:
             x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
 
-        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
+        model_wrap_cfg = ProgrammableCFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
         extra_args = {
             'cond': conditioning,
             'uncond': unconditional_conditioning,
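The 'b + (a - b)*scale' note above is plain classifier-free guidance: start from the unconditioned prediction and move toward the conditioned one by cond_scale. A throwaway numeric sketch (scalars standing in for latent tensors, values purely illustrative):

    uncond, cond, cond_scale = 0.2, 1.0, 7.5
    guided = uncond + (cond - uncond) * cond_scale   # 0.2 + 0.8 * 7.5 = 6.2
    # cond_scale == 1.0 reproduces cond exactly; larger scales push well past it,
    # which is why the negative prompt 'feels like a hammer'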
@@ -4,6 +4,8 @@ ldm.models.diffusion.sampler
 Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc
 
 '''
+from math import ceil
+
 import torch
 import numpy as np
 from tqdm import tqdm
@@ -411,3 +413,54 @@ class Sampler(object):
             return self.model.inner_model.q_sample(x0,ts)
         '''
         return self.model.q_sample(x0,ts)
+
+
+    @classmethod
+    def apply_weighted_conditioning_list(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
+        x_in = torch.cat([x] * 2)
+        t_in = torch.cat([t] * 2) # aka sigmas
+
+        deltas = None
+        uncond_latents = None
+        weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
+
+        # below is fugly omg
+        num_actual_conditionings = len(c_or_weighted_c_list)
+        conditionings = [uc] + [c for c,weight in weighted_cond_list]
+        weights = [1] + [weight for c,weight in weighted_cond_list]
+        chunk_count = ceil(len(conditionings)/2)
+        assert(len(conditionings)>=2, "need at least one uncond and one cond")
+        deltas = None
+        for chunk_index in range(chunk_count):
+            offset = chunk_index*2
+            chunk_size = min(2, len(conditionings)-offset)
+
+            if chunk_size == 1:
+                c_in = conditionings[offset]
+                latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
+                latents_b = None
+            else:
+                c_in = torch.cat(conditionings[offset:offset+2])
+                latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
+
+            # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining
+            if chunk_index == 0:
+                uncond_latents = latents_a
+                deltas = latents_b - uncond_latents
+            else:
+                deltas = torch.cat((deltas, latents_a - uncond_latents))
+                if latents_b is not None:
+                    deltas = torch.cat((deltas, latents_b - uncond_latents))
+
+        # merge the weighted deltas together into a single merged delta
+        per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
+        normalize = False
+        if normalize:
+            per_delta_weights /= torch.sum(per_delta_weights)
+        reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
+        deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
+
+        # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
+        # assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
+
+        return uncond_latents + deltas_merged * global_guidance_scale
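apply_weighted_conditioning_list generalizes that formula to a list of conditionings: every conditioning contributes a delta against the shared uncond latents, the deltas are combined by their conjunction weights, and the merged delta is scaled once. A minimal torch sketch of just the merge arithmetic (standalone, 2-D tensors instead of 4-D latents to keep it visible):

    import torch

    uncond = torch.zeros(1, 4)
    conds = [torch.ones(1, 4), 2 * torch.ones(1, 4)]
    weights = [1.0, 0.5]
    scale = 7.5

    deltas = torch.cat([c - uncond for c in conds])        # (2, 4), like the deltas tensor above
    w = torch.tensor(weights).reshape(-1, 1)               # stands in for reshaped_weights
    merged = torch.sum(deltas * w, dim=0, keepdim=True)    # deltas_merged
    out = uncond + merged * scale
    # with a single conditioning of weight 1 this collapses to uncond + (cond - uncond) * scale

The real method reshapes the weights to (n, 1, 1, 1) because the latents are 4-D; note also that normalize is hard-coded to False, so conjunction weights are applied unnormalized.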
@@ -1,3 +1,5 @@
+import math
+
 import torch
 import torch.nn as nn
 from functools import partial
@@ -454,6 +456,207 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     def encode(self, text, **kwargs):
         return self(text, **kwargs)
 
+
+class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
+
+    attention_weights_key = "attention_weights"
+
+    def build_token_list_fragment(self, fragment: str, weight: float) -> (torch.Tensor, torch.Tensor):
+        batch_encoding = self.tokenizer(
+            fragment,
+            truncation=True,
+            max_length=self.max_length,
+            return_length=True,
+            return_overflowing_tokens=False,
+            padding='none',
+            return_tensors='pt',
+        )
+        return batch_encoding, torch.ones_like(batch_encoding) * weight
+
+
+    def forward(self, text: list, **kwargs):
+        '''
+
+        :param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
+        weights shall be applied.
+        :param kwargs: If the keyword arg "attention_weights" is passed, it shall contain a batch of lists of weights
+        for the prompt fragments. In this case text must contain batches of lists of prompt fragments.
+        :return: A tensor of shape (B, 77, 768) containing weighted embeddings
+        '''
+        if self.attention_weights_key not in kwargs:
+            # fallback to base class implementation
+            return super().forward(text, **kwargs)
+
+        attention_weights = kwargs[self.attention_weights_key]
+        # self.transformer doesn't like receiving "attention_weights" as an argument
+        kwargs.pop(self.attention_weights_key)
+
+        batch_z = None
+        for fragments, weights in zip(text, attention_weights):
+
+            # First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
+            # applying a multiplier to the CFG scale on a per-token basis).
+            # For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept
+            # captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active
+            # interest, however small, in redness; what the user probably intends when they attach the number 0.01 to
+            # "red" is to tell SD that it should almost completely *ignore* redness).
+            # To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt
+            # string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the
+            # closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment.
+
+            # handle weights >=1
+            tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights)
+            base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
+
+            # this is our starting point
+            embeddings = base_embedding.unsqueeze(0)
+            per_embedding_weights = [1.0]
+
+            # now handle weights <1
+            # Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped
+            # with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting
+            # embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words
+            # removed.
+            # eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding
+            # for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it
+            # such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
+            for index, fragment_weight in enumerate(weights):
+                if fragment_weight < 1:
+                    fragments_without_this = fragments[:index] + fragments[index+1:]
+                    weights_without_this = weights[:index] + weights[index+1:]
+                    tokens, per_token_weights = self.get_tokens_and_weights(fragments_without_this, weights_without_this)
+                    embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
+
+                    embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
+                    # weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
+                    # if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
+                    # therefore:
+                    # fragment_weight = 1: we are at base_z => lerp weight 0
+                    # fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1
+                    # fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf
+                    # so let's use tan(), because:
+                    # tan is 0.0 at 0,
+                    # 1.0 at PI/4, and
+                    # inf at PI/2
+                    # -> tan((1-weight)*PI/2) should give us ideal lerp weights
+                    epsilon = 1e-9
+                    fragment_weight = max(epsilon, fragment_weight) # inf is bad
+                    embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
+                    # todo handle negative weight?
+
+                    per_embedding_weights.append(embedding_lerp_weight)
+
+            lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
+
+            print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
+
+            # append to batch
+            batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat((batch_z, lerped_embeddings.unsqueeze(0)), dim=1)
+
+        # should have shape (B, 77, 768)
+        print(f"assembled all tokens into tensor of shape {batch_z.shape}")
+
+        return batch_z
+
+    @classmethod
+    def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
+        per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
+        if normalize:
+            per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
+        reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
+        #reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
+        return torch.sum(embeddings * reshaped_weights, dim=1)
+        # lerped embeddings has shape (77, 768)
+
+
+    def get_tokens_and_weights(self, fragments: list[str], weights: list[float]) -> (torch.Tensor, torch.Tensor):
+        '''
+
+        :param fragments:
+        :param weights: Per-fragment weights (CFG scaling). No need for these to be normalized. They will not be normalized here and that's fine.
+        :return:
+        '''
+        # empty is meaningful
+        if len(fragments) == 0 and len(weights) == 0:
+            fragments = ['']
+            weights = [1]
+        item_encodings = self.tokenizer(
+            fragments,
+            truncation=True,
+            max_length=self.max_length,
+            return_overflowing_tokens=False,
+            padding='do_not_pad',
+            return_tensors=None, # just give me a list of ints
+        )['input_ids']
+        all_tokens = []
+        per_token_weights = []
+        print("all fragments:", fragments, weights)
+        for index, fragment in enumerate(item_encodings):
+            weight = weights[index]
+            print("processing fragment", fragment, weight)
+            fragment_tokens = item_encodings[index]
+            print("fragment", fragment, "processed to", fragment_tokens)
+            # trim bos and eos markers before appending
+            all_tokens.extend(fragment_tokens[1:-1])
+            per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
+
+        if len(all_tokens) > self.max_length - 2:
+            print("prompt is too long and has been truncated")
+            all_tokens = all_tokens[:self.max_length - 2]
+
+        # pad out to a 77-entry array: [eos_token, <prompt tokens>, eos_token, ..., eos_token]
+        # (77 = self.max_length)
+        pad_length = self.max_length - 1 - len(all_tokens)
+        all_tokens.insert(0, self.tokenizer.bos_token_id)
+        all_tokens.extend([self.tokenizer.eos_token_id] * pad_length)
+        per_token_weights.insert(0, 1)
+        per_token_weights.extend([1] * pad_length)
+
+        all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
+        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
+        print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
+        return all_tokens_tensor, per_token_weights_tensor
+
+    def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
+        '''
+        Build a tensor representing the passed-in tokens, each of which has a weight.
+        :param tokens: A tensor of shape (77) containing token ids (integers)
+        :param per_token_weights: A tensor of shape (77) containing weights (floats)
+        :param method: Whether to multiply the whole feature vector for each token or just its distance from an "empty" feature vector
+        :param kwargs: passed on to self.transformer()
+        :return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
+        '''
+        #print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
+        z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs)
+        batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
+
+        if weight_delta_from_empty:
+            empty_tokens = self.tokenizer([''] * z.shape[0],
+                                          truncation=True,
+                                          max_length=self.max_length,
+                                          padding='max_length',
+                                          return_tensors='pt'
+                                          )['input_ids'].to(self.device)
+            empty_z = self.transformer(input_ids=empty_tokens, **kwargs)
+            z_delta_from_empty = z - empty_z
+            weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
+
+            weighted_z_delta_from_empty = (weighted_z-empty_z)
+            print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
+
+            #print("using empty-delta method, first 5 rows:")
+            #print(weighted_z[:5])
+
+            return weighted_z
+
+        else:
+            original_mean = z.mean()
+            z *= batch_weights_expanded
+            after_weighting_mean = z.mean()
+            # correct the mean. not sure if this is right but it's what the automatic1111 fork of SD does
+            mean_correction_factor = original_mean/after_weighting_mean
+            z *= mean_correction_factor
+            return z
+
+
 class FrozenCLIPTextEmbedder(nn.Module):
     """
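The tan() mapping inside forward() converts a fragment weight in (0, 1] into a lerp weight in [0, inf). A self-contained sketch of just that curve:

    import math

    for w in [1.0, 0.75, 0.5, 0.25, 1e-9]:
        print(w, math.tan((1.0 - w) * math.pi / 2))
    # 1.0  -> 0.0       stay entirely at the base embedding
    # 0.5  -> 1.0       after normalization, half base / half 'fragment removed'
    # 0.25 -> ~2.414    the 'fragment removed' embedding dominates
    # ~0   -> huge      the fragment is effectively gone

apply_embedding_weights then normalizes [1.0, lerp_weight, ...], so the output is a convex combination of the base embedding and the embeddings with low-weight fragments removed.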
tests/test_prompt_parser.py (new file, 136 lines)
@@ -0,0 +1,136 @@
import unittest

from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt

def parse_prompt(prompt_string):
    pp = PromptParser()
    #print(f"parsing '{prompt_string}'")
    parse_result = pp.parse(prompt_string)
    #print(f"-> parsed '{prompt_string}' to {parse_result}")
    return parse_result

class PromptParserTestCase(unittest.TestCase):

    def test_empty(self):
        self.assertEqual(Conjunction([FlattenedPrompt([('', 1)])]), parse_prompt(''))

    def test_basic(self):
        self.assertEqual(Conjunction([FlattenedPrompt([('fire (flames)', 1)])]), parse_prompt("fire (flames)"))
        self.assertEqual(Conjunction([FlattenedPrompt([("fire flames", 1)])]), parse_prompt("fire flames"))
        self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames", 1)])]), parse_prompt("fire, flames"))
        self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames , fire", 1)])]), parse_prompt("fire, flames , fire"))

    def test_attention(self):
        self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.5)])]), parse_prompt("0.5(flames)"))
        self.assertEqual(Conjunction([FlattenedPrompt([('fire flames', 0.5)])]), parse_prompt("0.5(fire flames)"))
        self.assertEqual(Conjunction([FlattenedPrompt([('flames', 1.1)])]), parse_prompt("+(flames)"))
        self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.9)])]), parse_prompt("-(flames)"))
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1), ('flames', 0.5)])]), parse_prompt("fire 0.5(flames)"))
        self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(1.1, 2))])]), parse_prompt("++(flames)"))
        self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(0.9, 2))])]), parse_prompt("--(flames)"))
        self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))])]), parse_prompt("---(flowers) +++flames"))
        self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))])]), parse_prompt("---(flowers) +++flames"))
        self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames+', pow(1.1, 3))])]),
                         parse_prompt("---(flowers) +++flames+"))
        self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1)])]),
                         parse_prompt("+(pretty flowers)"))
        self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1), (', the flames are too hot', 1)])]),
                         parse_prompt("+(pretty flowers), the flames are too hot"))

    def test_no_parens_attention_runon(self):
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("++fire flames"))
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("--fire flames"))
        self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("flowers ++fire flames"))
        self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("flowers --fire flames"))


    def test_explicit_conjunction(self):
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and(1,1)'))
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()'))
        self.assertEqual(
            Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()'))
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("2.0(fire)", "-flames").and()'))
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]),
                                      FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()'))

    def test_conjunction_weights(self):
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[2.0,1.0]), parse_prompt('("fire", "flames").and(2,1)'))
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[1.0,2.0]), parse_prompt('("fire", "flames").and(1,2)'))

        with self.assertRaises(PromptParser.ParsingException):
            parse_prompt('("fire", "flames").and(2)')
            parse_prompt('("fire", "flames").and(2,1,2)')

    def test_complex_conjunction(self):
        self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]),
                         parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle)\").and(0.5, 0.5)"))

    def test_badly_formed(self):
        def make_untouched_prompt(prompt):
            return Conjunction([FlattenedPrompt([(prompt, 1.0)])])

        def assert_if_prompt_string_not_untouched(prompt):
            self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt))

        assert_if_prompt_string_not_untouched('a test prompt')
        assert_if_prompt_string_not_untouched('a badly (formed test prompt')
        assert_if_prompt_string_not_untouched('a badly formed test+ prompt')
        assert_if_prompt_string_not_untouched('a badly (formed test+ prompt')
        assert_if_prompt_string_not_untouched('a badly (formed test+ )prompt')
        assert_if_prompt_string_not_untouched('a badly (formed test+ )prompt')
        assert_if_prompt_string_not_untouched('(((a badly (formed test+ )prompt')
        assert_if_prompt_string_not_untouched('(a (ba)dly (f)ormed test+ prompt')
        self.assertEqual(Conjunction([FlattenedPrompt([('(a (ba)dly (f)ormed test+', 1.0), ('prompt', 1.1)])]),
                         parse_prompt('(a (ba)dly (f)ormed test+ +prompt'))
        self.assertEqual(Conjunction([Blend([FlattenedPrompt([('((a badly (formed test+', 1.0)])], weights=[1.0])]),
                         parse_prompt('("((a badly (formed test+ ").blend(1.0)'))

    def test_blend(self):
        self.assertEqual(Conjunction(
            [Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]),
            parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)")
        )
        self.assertEqual(Conjunction([Blend(
            [FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('hi', 1.0)])],
            [0.7, 0.3, 1.0])]),
            parse_prompt("(\"fire\", \"fire flames\", \"hi\").blend(0.7, 0.3, 1.0)")
        )
        self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]),
                                             FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]),
                                             FlattenedPrompt([('hi', 1.0)])],
                                            weights=[0.7, 0.3, 1.0])]),
                         parse_prompt("(\"fire\", \"fire flames ++(hot)\", \"hi\").blend(0.7, 0.3, 1.0)")
        )
        # blend a single entry is not a failure
        self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]),
                         parse_prompt("(\"fire\").blend(0.7)")
        )
        # blend with empty
        self.assertEqual(
            Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
            parse_prompt("(\"fire\", \"\").blend(0.7, 1)")
        )
        self.assertEqual(
            Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
            parse_prompt("(\"fire\", \" \").blend(0.7, 1)")
        )
        self.assertEqual(
            Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
            parse_prompt("(\"fire\", \" \").blend(0.7, 1)")
        )
        self.assertEqual(
            Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([(',', 1.0)])], [0.7, 1.0])]),
            parse_prompt("(\"fire\", \" , \").blend(0.7, 1)")
        )


    def test_nested(self):
        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)])]),
                         parse_prompt('fire 2.0(flames 1.5(trees))'))
        self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]),
                                                     FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])],
                                            weights=[1.0, 1.0])]),
                         parse_prompt('("fire ++(flames)", "mountain 2(man)").blend(1,1)'))

if __name__ == '__main__':
    unittest.main()
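The tests exercise only the parser (ldm.invoke.prompt_parser pulls in nothing but pyparsing), so they should run without a model checkpoint; presumably, from the repository root with pyparsing installed:

    python -m unittest tests.test_prompt_parser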