From c3b992db968984fc3bfa4f1311461dc0177eb8af Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sat, 15 Oct 2022 23:44:54 +0200 Subject: [PATCH 1/4] Squashed commit of the following: commit 9bb0b5d0036c4dffbb72ce11e097fae4ab63defd Author: Damian at mba Date: Sat Oct 15 23:43:41 2022 +0200 undo local_files_only stuff commit eed93f5d30c34cfccaf7497618ae9af17a5ecfbb Author: Damian at mba Date: Sat Oct 15 23:40:37 2022 +0200 Revert "Merge branch 'development-invoke' into fix-prompts" This reverts commit 7c40892a9f184f7e216f14d14feb0411c5a90e24, reversing changes made to e3f2dd62b0548ca6988818ef058093a4f5b022f2. commit f06d6024e345c69e6d5a91ab5423925a68ee95a7 Author: Damian at mba Date: Thu Oct 13 23:30:16 2022 +0200 more efficiently handle multiple conditioning commit 5efdfcbcd980ce6202ab74e7f90e7415ce7260da Merge: b9c0dc5 ac08bb6 Author: Damian at mba Date: Thu Oct 13 14:51:01 2022 +0200 Merge branch 'optional-disable-karras-schedule' into fix-prompts commit ac08bb6fd25e19a9d35cf6c199e66500fb604af1 Author: Damian at mba Date: Thu Oct 13 14:50:43 2022 +0200 append '*use_model_sigmas*' to prompt string to use model sigmas commit 70d8c05a3ff329409f76204f4af94e55d468ab8b Author: Damian at mba Date: Thu Oct 13 12:12:17 2022 +0200 make karras scheduling switchable commit d60df54f69968e2fb22809c55e23b3c02f37ad63 replaced the model's own scheduling with karras scheduling. this has changed image generation (seems worse now?) this commit wraps the change in a bool. commit b9c0dc5f1a658a0e6c3936000e9ae559e1c7a1db Author: Damian at mba Date: Wed Oct 12 20:16:00 2022 +0200 add test of more complex conjunction commit 9ac0c15cc0d7b5f6df3289d3ad474260972a17be Author: Damian at mba Date: Wed Oct 12 17:18:25 2022 +0200 improve comments commit ad33bce60590b87b2a93e90f16dc9d3e935d04a5 Author: Damian at mba Date: Wed Oct 12 17:04:46 2022 +0200 put back thresholding stuff commit 4852c698a325049834ba0d4b358f07210bc7171a Author: Damian at mba Date: Wed Oct 12 14:25:02 2022 +0200 notes on improving conjunction efficiency commit a53bb1e5b68025d09642b935ae6a9a015cfaf2d6 Author: Damian at mba Date: Wed Oct 12 14:14:33 2022 +0200 optional weights support for Conjunction commit fec79ab15e4f0c84dd61cb1b45a5e6a72ae4aaeb Author: Damian at mba Date: Wed Oct 12 12:07:27 2022 +0200 fix blend error and log parsing output commit 1f751c2a039f9c97af57b18e0f019512631d5a25 Author: Damian at mba Date: Wed Oct 12 10:33:33 2022 +0200 fix broken euler sampler commit 02f8148d17efe4b6bde8d29b827092a0626363ee Author: Damian at mba Date: Wed Oct 12 10:24:20 2022 +0200 cleanup prompt parser commit 8028d49ae6c16c0d6ec9c9de9c12d56c32201421 Author: Damian at mba Date: Wed Oct 12 10:14:18 2022 +0200 explicit conjunction, improve flattening logic commit 8a1710892185f07eb77483f7edae0fc4d6bbb250 Author: Damian at mba Date: Tue Oct 11 22:59:30 2022 +0200 adapt multi-conditioning to also work with ddim commit 53802a839850d0d1ff017c6bafe457c4bed750b0 Author: Damian at mba Date: Tue Oct 11 22:31:42 2022 +0200 unconditioning is also fancy-prompt-syntaxable commit 7c40892a9f184f7e216f14d14feb0411c5a90e24 Merge: e3f2dd6 dbe0da4 Author: Damian at mba Date: Tue Oct 11 21:39:54 2022 +0200 Merge branch 'development-invoke' into fix-prompts commit e3f2dd62b0548ca6988818ef058093a4f5b022f2 Merge: eef0e48 06f542e Author: Damian at mba Date: Tue Oct 11 21:38:09 2022 +0200 Merge remote-tracking branch 'upstream/development' into fix-prompts commit eef0e484c2eaa1bd4e0e0b1d3f8d7bba38478144 Author: Damian at mba Date: Tue Oct 11 21:26:25 2022 +0200 fix run-on paren-less attention, add some comments commit fd29afdf0e9f5e0cdc60239e22480c36ca0aaeca Author: Damian at mba Date: Tue Oct 11 21:03:02 2022 +0200 python 3.9 compatibility commit 26f7646eef7f39bc8f7ce805e747df0f723464da Author: Damian at mba Date: Tue Oct 11 20:58:42 2022 +0200 first pass connecting PromptParser to conditioning commit ae53dff3796d7b9a5e7ed30fa1edb0374af6cd8d Author: Damian at mba Date: Tue Oct 11 20:51:15 2022 +0200 update frontend dist commit 9be4a59a2d76f49e635474b5984bfca826a5dab4 Author: Damian at mba Date: Tue Oct 11 19:01:39 2022 +0200 fix issues with correctness checking FlattenedPrompt commit 3be212323eab68e72a363a654124edd9809e4cf0 Author: Damian at mba Date: Tue Oct 11 18:43:16 2022 +0200 parsing nested seems to work pretty ok commit acd73eb08cf67c27cac8a22934754321256f56a9 Author: Damian at mba Date: Tue Oct 11 18:26:17 2022 +0200 wip introducing FlattenedPrompt class commit 71698d5c7c2ac855b690d8ef67e8830148c59eda Author: Damian at mba Date: Tue Oct 11 15:59:42 2022 +0200 recursive attention weighting seems to actually work commit a4e1ec6b20deb7cc0cd12737bdbd266e56144709 Author: Damian at mba Date: Tue Oct 11 15:06:24 2022 +0200 now apparently almost supported nested attention commit da76fd1ddf22a3888cdc08fd4fed38d8b178e524 Author: Damian at mba Date: Tue Oct 11 13:23:37 2022 +0200 wip prompt parsing commit dbe0da4572c2ac22f26a7afd722349a5680a9e47 Author: Kyle Schouviller Date: Mon Oct 10 22:32:35 2022 -0700 Adding node-based invocation apps commit 8f2a2ffc083366de74d7dae471b50b6f98a7c5f8 Author: Damian at mba Date: Mon Oct 10 19:03:18 2022 +0200 fix merge issues commit 73118dee2a8f4891700756e014caf1c9ca629267 Merge: fd00844 12413b0 Author: Damian at mba Date: Mon Oct 10 12:42:48 2022 +0200 Merge remote-tracking branch 'upstream/development' into fix-prompts commit fd0084413541013c2cf71e006af0392719bef53d Author: Damian at mba Date: Mon Oct 10 12:39:38 2022 +0200 wip prompt parsing commit 0be9363db9307859d2b65cffc6af01f57d7873a4 Author: Damian at mba Date: Mon Oct 10 03:20:06 2022 +0200 better +/- attention parsing commit 5383f691874a58ab01cda1e4fac6cf330146526a Author: Damian at mba Date: Mon Oct 10 02:27:47 2022 +0200 prompt parser seems to work commit 591d098a33ce35462428d8c169501d8ed73615ab Author: Damian at mba Date: Sun Oct 9 20:25:37 2022 +0200 supports weighting unconditioning, cross-attention with | commit 7a7220563aa05a2980235b5b908362f66b728309 Author: Damian at mba Date: Sun Oct 9 18:15:56 2022 +0200 i think cross attention might be working? commit 951ed391e7126bff228c18b2db304ad28d59644a Author: Damian at mba Date: Sun Oct 9 16:04:54 2022 +0200 weighted CFG denoiser working with a single item commit ee532a0c2827368c9e45a6a5f3975666402873da Author: Damian at mba Date: Sun Oct 9 06:33:40 2022 +0200 wip probably doesn't work or compile commit 14654bcbd207b9ca28a6cbd37dbd967d699b062d Author: Damian at mba Date: Fri Oct 7 18:11:48 2022 +0200 use tan() to calculate embedding weight for <1 attentions commit 1a8e76b31aa5abf5150419ebf3b29d4658d07f2b Author: Damian at mba Date: Fri Oct 7 16:14:54 2022 +0200 fix bad math.max reference commit f697ff896875876ccaa1e5527405bdaa7ed27cde Author: Damian at mba Date: Fri Oct 7 15:55:57 2022 +0200 respect http[s]x protocol when making socket.io middleware commit 41d3dd4eeae8d4efb05dfb44fc6d8aac5dc468ab Author: Damian at mba Date: Fri Oct 7 13:29:54 2022 +0200 fractional weighting works, by blending with prompts excluding the word commit 087fb6dfb3e8f5e84de8c911f75faa3e3fa3553c Author: Damian at mba Date: Fri Oct 7 10:52:03 2022 +0200 wip doing weights <1 by averaging with conditioning absent the lower-weighted fragment commit 3c49e3f3ec7c18dc60f3e18ed2f7f0d97aad3a47 Author: Damian at mba Date: Fri Oct 7 10:36:15 2022 +0200 notate CFGDenoiser, perhaps commit d2bcf1bb522026ebf209ad0103f6b370383e5070 Author: Damian at mba Date: Thu Oct 6 05:04:47 2022 +0200 hack blending syntax to test attention weighting more extensively commit 94904ef2cf917f74ec23ef7a570e12ff8255b048 Author: Damian at mba Date: Thu Oct 6 04:56:37 2022 +0200 conditioning works, apparently commit 7c6663ddd70f665fd1308b6dd74f92ca393a8df5 Author: Damian at mba Date: Thu Oct 6 02:20:24 2022 +0200 attention weighting, definitely works in positive direction commit 5856d453a9b020bc1a28ff643ae1f58c12c9be73 Author: Damian at mba Date: Tue Oct 4 19:02:14 2022 +0200 wip bubbling weights down commit a2ed14fd9b7d3cb36b6c5348018b364c76d1e892 Author: Damian at mba Date: Tue Oct 4 17:35:39 2022 +0200 bring in changes from PC --- backend/server.py | 2 +- configs/stable-diffusion/v1-inference.yaml | 2 +- ldm/generate.py | 6 +- ldm/invoke/conditioning.py | 82 +++-- ldm/invoke/prompt_parser.py | 331 +++++++++++++++++++++ ldm/models/diffusion/ddim.py | 18 +- ldm/models/diffusion/ddpm.py | 4 +- ldm/models/diffusion/ksampler.py | 29 +- ldm/models/diffusion/sampler.py | 53 ++++ ldm/modules/encoders/modules.py | 203 +++++++++++++ tests/test_prompt_parser.py | 136 +++++++++ 11 files changed, 823 insertions(+), 43 deletions(-) create mode 100644 ldm/invoke/prompt_parser.py create mode 100644 tests/test_prompt_parser.py diff --git a/backend/server.py b/backend/server.py index 7b8a8a5a69..f14c141e12 100644 --- a/backend/server.py +++ b/backend/server.py @@ -527,7 +527,7 @@ def parameters_to_generated_image_metadata(parameters): rfc_dict["sampler"] = parameters["sampler_name"] # display weighted subprompts (liable to change) - subprompts = split_weighted_subprompts(parameters["prompt"]) + subprompts = split_weighted_subprompts(parameters["prompt"], skip_normalize=True) subprompts = [{"prompt": x[0], "weight": x[1]} for x in subprompts] rfc_dict["prompt"] = subprompts diff --git a/configs/stable-diffusion/v1-inference.yaml b/configs/stable-diffusion/v1-inference.yaml index 9c773077b6..baf91f6e26 100644 --- a/configs/stable-diffusion/v1-inference.yaml +++ b/configs/stable-diffusion/v1-inference.yaml @@ -76,4 +76,4 @@ model: target: torch.nn.Identity cond_stage_config: - target: ldm.modules.encoders.modules.FrozenCLIPEmbedder + target: ldm.modules.encoders.modules.WeightedFrozenCLIPEmbedder diff --git a/ldm/generate.py b/ldm/generate.py index 7fb68dec0a..965d37a240 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -438,7 +438,7 @@ class Generate: sampler=self.sampler, steps=steps, cfg_scale=cfg_scale, - conditioning=(uc, c), + conditioning=(uc, c), # here change to arrays ddim_eta=ddim_eta, image_callback=image_callback, # called after the final image is generated step_callback=step_callback, # called after each intermediate image is generated @@ -477,6 +477,10 @@ class Generate: print('**Interrupted** Partial results will be returned.') else: raise KeyboardInterrupt + # brute-force fallback + except Exception as e: + print(traceback.format_exc(), file=sys.stderr) + print('>> Could not generate image.') toc = time.time() print('>> Usage stats:') diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index fedd965a2c..9b67d5040d 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -12,42 +12,76 @@ log_tokenization() print out colour-coded tokens and warn if trunca import re import torch -def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False): +from .prompt_parser import PromptParser, Fragment, Attention, Blend, Conjunction, FlattenedPrompt +from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder + + +def get_uc_and_c(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False): + # Extract Unconditioned Words From Prompt unconditioned_words = '' unconditional_regex = r'\[(.*?)\]' - unconditionals = re.findall(unconditional_regex, prompt) + unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned) if len(unconditionals) > 0: unconditioned_words = ' '.join(unconditionals) # Remove Unconditioned Words From Prompt unconditional_regex_compile = re.compile(unconditional_regex) - clean_prompt = unconditional_regex_compile.sub(' ', prompt) - prompt = re.sub(' +', ' ', clean_prompt) + clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned) + prompt_string_cleaned = re.sub(' +', ' ', clean_prompt) + else: + prompt_string_cleaned = prompt_string_uncleaned - uc = model.get_learned_conditioning([unconditioned_words]) + pp = PromptParser() - # get weighted sub-prompts - weighted_subprompts = split_weighted_subprompts( - prompt, skip_normalize - ) + def build_conditioning_list(prompt_string:str): + parsed_conjunction: Conjunction = pp.parse(prompt_string) + print(f"parsed '{prompt_string}' to {parsed_conjunction}") + assert (type(parsed_conjunction) is Conjunction) + + conditioning_list = [] + def make_embeddings_for_flattened_prompt(flattened_prompt: FlattenedPrompt): + if type(flattened_prompt) is not FlattenedPrompt: + raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead" + fragments = [x[0] for x in flattened_prompt.children] + attention_weights = [x[1] for x in flattened_prompt.children] + print(fragments, attention_weights) + return model.get_learned_conditioning([fragments], attention_weights=[attention_weights]) + + for part,weight in zip(parsed_conjunction.prompts, parsed_conjunction.weights): + if type(part) is Blend: + blend:Blend = part + embeddings_to_blend = None + for flattened_prompt in blend.prompts: + this_embedding = make_embeddings_for_flattened_prompt(flattened_prompt) + embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat((embeddings_to_blend, this_embedding)) + blended_embeddings = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0), blend.weights, normalize=blend.normalize_weights) + conditioning_list.append((blended_embeddings, weight)) + else: + flattened_prompt: FlattenedPrompt = part + embeddings = make_embeddings_for_flattened_prompt(flattened_prompt) + conditioning_list.append((embeddings, weight)) + + return conditioning_list + + positive_conditioning_list = build_conditioning_list(prompt_string_cleaned) + negative_conditioning_list = build_conditioning_list(unconditioned_words) + + if len(negative_conditioning_list) == 0: + negative_conditioning = model.get_learned_conditioning([['']], attention_weights=[[1]]) + else: + if len(negative_conditioning_list)>1: + print("cannot do conjunctions on unconditioning for now") + negative_conditioning = negative_conditioning_list[0][0] + + #positive_conditioning_list.append((get_blend_prompts_and_weights(prompt), this_weight)) + #print("got empty_conditionining with shape", empty_conditioning.shape, "c[0][0] with shape", positive_conditioning[0][0].shape) + + # "unconditioned" means "the conditioning tensor is empty" + uc = negative_conditioning + c = positive_conditioning_list - if len(weighted_subprompts) > 1: - # i dont know if this is correct.. but it works - c = torch.zeros_like(uc) - # normalize each "sub prompt" and add it - for subprompt, weight in weighted_subprompts: - log_tokenization(subprompt, model, log_tokens, weight) - c = torch.add( - c, - model.get_learned_conditioning([subprompt]), - alpha=weight, - ) - else: # just standard 1 prompt - log_tokenization(prompt, model, log_tokens, 1) - c = model.get_learned_conditioning([prompt]) - uc = model.get_learned_conditioning([unconditioned_words]) return (uc, c) def split_weighted_subprompts(text, skip_normalize=False)->list: diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py new file mode 100644 index 0000000000..c976918291 --- /dev/null +++ b/ldm/invoke/prompt_parser.py @@ -0,0 +1,331 @@ +import pyparsing +import pyparsing as pp +from pyparsing import original_text_for + + +class Prompt(): + + def __init__(self, parts: list): + for c in parts: + allowed_types = [Fragment, Attention, CFGScale] + if type(c) not in allowed_types: + raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {allowed_types} are allowed") + self.children = parts + def __repr__(self): + return f"Prompt:{self.children}" + def __eq__(self, other): + return type(other) is Prompt and other.children == self.children + +class FlattenedPrompt(): + def __init__(self, parts: list): + # verify type correctness + for c in parts: + if type(c) is not tuple: + raise PromptParser.ParsingException( + f"FlattenedPrompt cannot contain {type(c)}, only ('text', weight) tuples are allowed") + text = c[0] + weight = c[1] + if type(text) is not str: + raise PromptParser.ParsingException(f"FlattenedPrompt cannot contain {type(c)}, only ('text', weight) tuples are allowed") + if type(weight) is not float and type(weight) is not int: + raise PromptParser.ParsingException( + f"FlattenedPrompt cannot contain {type(c)}, only ('text', weight) tuples are allowed") + # all looks good + self.children = parts + + def __repr__(self): + return f"FlattenedPrompt:{self.children}" + def __eq__(self, other): + return type(other) is FlattenedPrompt and other.children == self.children + + +class Attention(): + + def __init__(self, weight: float, children: list): + self.weight = weight + self.children = children + #print(f"A: requested attention '{children}' to {weight}") + + def __repr__(self): + return f"Attention:'{self.children}' @ {self.weight}" + def __eq__(self, other): + return type(other) is Attention and other.weight == self.weight and other.fragment == self.fragment + + +class CFGScale(): + def __init__(self, scale_factor: float, fragment: str): + self.fragment = fragment + self.scale_factor = scale_factor + #print(f"S: requested CFGScale '{fragment}' x {scale_factor}") + + def __repr__(self): + return f"CFGScale:'{self.fragment}' x {self.scale_factor}" + def __eq__(self, other): + return type(other) is CFGScale and other.scale_factor == self.scale_factor and other.fragment == self.fragment + + + +class Fragment(): + def __init__(self, text: str): + assert(type(text) is str) + self.text = text + + def __repr__(self): + return "Fragment:'"+self.text+"'" + def __eq__(self, other): + return type(other) is Fragment and other.text == self.text + +class Conjunction(): + def __init__(self, prompts: list, weights: list = None): + # force everything to be a Prompt + #print("making conjunction with", parts) + self.prompts = [x if (type(x) is Prompt or type(x) is Blend or type(x) is FlattenedPrompt) + else Prompt(x) for x in prompts] + self.weights = [1.0]*len(self.prompts) if weights is None else list(weights) + if len(self.weights) != len(self.prompts): + raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}") + self.type = 'AND' + + def __repr__(self): + return f"Conjunction:{self.prompts} | weights {self.weights}" + def __eq__(self, other): + return type(other) is Conjunction \ + and other.prompts == self.prompts \ + and other.weights == self.weights + + +class Blend(): + def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True): + #print("making Blend with prompts", prompts, "and weights", weights) + if len(prompts) != len(weights): + raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}") + for c in prompts: + if type(c) is not Prompt and type(c) is not FlattenedPrompt: + raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts")) + # upcast all lists to Prompt objects + self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt) + else Prompt(x) for x in prompts] + self.prompts = prompts + self.weights = weights + self.normalize_weights = normalize_weights + + def __repr__(self): + return f"Blend:{self.prompts} | weights {self.weights}" + def __eq__(self, other): + return other.__repr__() == self.__repr__() + + +class PromptParser(): + + class ParsingException(Exception): + pass + + def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9): + + self.attention_plus_base = attention_plus_base + self.attention_minus_base = attention_minus_base + + self.root = self.build_parser_logic() + + + def parse(self, prompt: str) -> [list]: + ''' + :param prompt: The prompt string to parse + :return: a tuple + ''' + #print(f"!!parsing '{prompt}'") + + if len(prompt.strip()) == 0: + return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0]) + + root = self.root.parse_string(prompt) + #print(f"'{prompt}' parsed to root", root) + #fused = fuse_fragments(parts) + #print("fused to", fused) + + return self.flatten(root[0]) + + def flatten(self, root: Conjunction): + + def fuse_fragments(items): + # print("fusing fragments in ", items) + result = [] + for x in items: + last_weight = result[-1][1] if len(result) > 0 else None + this_text = x[0] + this_weight = x[1] + if last_weight is not None and last_weight == this_weight: + last_text = result[-1][0] + result[-1] = (last_text + ' ' + this_text, last_weight) + else: + result.append(x) + return result + + def flatten_internal(node, weight_scale, results, prefix): + #print(prefix + "flattening", node, "...") + if type(node) is pp.ParseResults: + for x in node: + results = flatten_internal(x, weight_scale, results, prefix+'pr') + #print(prefix, " ParseResults expanded, results is now", results) + elif type(node) is Fragment: + results.append((node.text, float(weight_scale))) + elif type(node) is Attention: + #if node.weight < 1: + # todo: inject a blend when flattening attention with weight <1" + for c in node.children: + results = flatten_internal(c, weight_scale*node.weight, results, prefix+' ') + elif type(node) is Blend: + flattened_subprompts = [] + #print(" flattening blend with prompts", node.prompts, "weights", node.weights) + for prompt in node.prompts: + # prompt is a list + flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ') + results += [Blend(prompts=flattened_subprompts, weights=node.weights)] + elif type(node) is Prompt: + #print(prefix + "about to flatten Prompt with children", node.children) + flattened_prompt = [] + for child in node.children: + flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ') + results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))] + #print(prefix + "after flattening Prompt, results is", results) + else: + raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}") + #print(prefix + "-> after flattening", type(node), "results is", results) + return results + + #print("flattening", root) + + flattened_parts = [] + for part in root.prompts: + flattened_parts += flatten_internal(part, 1.0, [], ' C| ') + weights = root.weights + return Conjunction(flattened_parts, weights) + + + + def build_parser_logic(self): + + lparen = pp.Literal("(").suppress() + rparen = pp.Literal(")").suppress() + # accepts int or float notation, always maps to float + number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float)) + SPACE_CHARS = ' \t\n' + + prompt_part = pp.Forward() + word = pp.Forward() + + def make_fragment(x): + #print("### making fragment for", x) + if type(x) is str: + return Fragment(x) + elif type(x) is pp.ParseResults or type(x) is list: + return Fragment(' '.join([s for s in x])) + else: + raise PromptParser.ParsingException("Cannot make fragment from " + str(x)) + + # attention control of the form +(phrase) / -(phrase) / (phrase) + # phrase can be multiple words, can have multiple +/- signs to increase the effect or type a floating point or integer weight + attention = pp.Forward() + attention_head = (number | pp.Word('+') | pp.Word('-'))\ + .set_name("attention_head")\ + .set_debug(False) + fragment_inside_attention = pp.CharsNotIn(SPACE_CHARS+'()')\ + .set_parse_action(make_fragment)\ + .set_name("fragment_inside_attention")\ + .set_debug(False) + attention_with_parens = pp.Forward() + attention_with_parens_body = pp.nested_expr(content=pp.delimited_list((attention_with_parens | fragment_inside_attention), delim=SPACE_CHARS)) + attention_with_parens << (attention_head + attention_with_parens_body) + + def make_attention(x): + # print("making Attention from parsing with args", x0, x1) + weight = 1 + # number(str) + if type(x[0]) is float or type(x[0]) is int: + weight = float(x[0]) + # +(str) or -(str) or +str or -str + elif type(x[0]) is str: + base = self.attention_plus_base if x[0][0] == '+' else self.attention_minus_base + weight = pow(base, len(x[0])) + # print("Making attention with children of type", [str(type(x)) for x in x1]) + return Attention(weight=weight, children=x[1]) + + attention_with_parens.set_parse_action(make_attention)\ + .set_name("attention_with_parens")\ + .set_debug(False) + + # attention control of the form ++word --word (no parens) + attention_without_parens = ( + (pp.Word('+') | pp.Word('-')) + + pp.CharsNotIn(SPACE_CHARS+'()').set_parse_action(lambda x: [[make_fragment(x)]]) + )\ + .set_name("attention_without_parens")\ + .set_debug(False) + attention_without_parens.set_parse_action(make_attention) + + attention << (attention_with_parens | attention_without_parens)\ + .set_name("attention")\ + .set_debug(False) + + # fragments of text with no attention control + word << pp.Word(pp.printables).set_parse_action(lambda x: Fragment(' '.join([s for s in x]))) + word.set_name("word") + word.set_debug(False) + prompt_part << (attention | word) + prompt_part.set_debug(False) + prompt_part.set_name("prompt_part") + + # root prompt definition + prompt = pp.Group(pp.OneOrMore(prompt_part))\ + .set_parse_action(lambda x: Prompt(x[0])) + + # weighted blend of prompts + # ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or + # int weights. + # can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c) + + def make_prompt_from_quoted_string(x): + #print(' got quoted prompt', x) + + x_unquoted = x[0][1:-1] + if len(x_unquoted.strip()) == 0: + # print(' b : just an empty string') + return Prompt([Fragment('')]) + # print(' b parsing ', c_unquoted) + x_parsed = prompt.parse_string(x_unquoted) + #print(" quoted prompt was parsed to", type(x_parsed),":", x_parsed) + return x_parsed[0] + + quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string) + quoted_prompt.set_name('quoted_prompt') + + blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms') + blend_weights = pp.delimited_list(number).set_name('blend_weights') + blend = pp.Group(lparen + pp.Group(blend_terms) + rparen + + pp.Literal(".blend").suppress() + + lparen + pp.Group(blend_weights) + rparen).set_name('blend') + blend.set_debug(False) + + + blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1])) + + conjunction_terms = blend_terms.copy().set_name('conjunction_terms') + conjunction_weights = blend_weights.copy().set_name('conjunction_weights') + conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen + + pp.Literal(".and").suppress() + + lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction') + def make_conjunction(x): + parts_raw = x[0][0] + weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw) + parts = [part for part in parts_raw] + return Conjunction(parts, weights) + conjunction_with_parens_and_quotes.set_parse_action(make_conjunction) + + implicit_conjunction = pp.OneOrMore(blend | prompt) + implicit_conjunction.set_parse_action(lambda x: Conjunction(x)) + + conjunction = conjunction_with_parens_and_quotes | implicit_conjunction + conjunction.set_debug(False) + + # top-level is a conjunction of one or more blends or prompts + return conjunction diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index f5dada8627..5120d92c48 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -34,23 +34,21 @@ class DDIMSampler(Sampler): b, *_, device = *x.shape, x.device if ( - unconditional_conditioning is None - or unconditional_guidance_scale == 1.0 + (unconditional_conditioning is None + or unconditional_guidance_scale == 1.0) + and c is not list ): e_t = self.model.apply_model(x, t, c) else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - c_in = torch.cat([unconditional_conditioning, c]) - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * ( - e_t - e_t_uncond - ) + e_t = self.apply_weighted_conditioning_list(x, t, self.model.apply_model, unconditional_conditioning, c, unconditional_guidance_scale) if score_corrector is not None: assert self.model.parameterization == 'eps' + if c is list and len(c)>1: + print("warning: ddim score modifier currently ignores all but the first part of the prompt conjunction, this is probably wrong") + corrector_c = [c[0][0] if c is list else c] e_t = score_corrector.modify_score( - self.model, e_t, x, t, c, **corrector_kwargs + self.model, e_t, x, t, corrector_c, **corrector_kwargs ) alphas = ( diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index 4b62b5e393..0f55786323 100644 --- a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -820,13 +820,13 @@ class LatentDiffusion(DDPM): ) return self.scale_factor * z - def get_learned_conditioning(self, c): + def get_learned_conditioning(self, c, attention_weights=None): if self.cond_stage_forward is None: if hasattr(self.cond_stage_model, 'encode') and callable( self.cond_stage_model.encode ): c = self.cond_stage_model.encode( - c, embedding_manager=self.embedding_manager + c, embedding_manager=self.embedding_manager, attention_weights=attention_weights ) if isinstance(c, DiagonalGaussianDistribution): c = c.mode() diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py index ac0615b30c..4d37c8cf9b 100644 --- a/ldm/models/diffusion/ksampler.py +++ b/ldm/models/diffusion/ksampler.py @@ -38,7 +38,8 @@ class CFGDenoiser(nn.Module): x_in = torch.cat([x] * 2) sigma_in = torch.cat([sigma] * 2) cond_in = torch.cat([uncond, cond]) - uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + unconditioned_x, conditioned_x = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + if self.warmup < self.warmup_max: thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) self.warmup += 1 @@ -46,7 +47,28 @@ class CFGDenoiser(nn.Module): thresh = self.threshold if thresh > self.threshold: thresh = self.threshold - return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh) + + # damian0815 thinking out loud notes: + # b + (a - b)*scale + # starting at the output that emerges applying the negative prompt (by default ''), + # (-> this is why the unconditioning feels like hammer) + # move toward the positive prompt by an amount controlled by cond_scale. + return cfg_apply_threshold(unconditioned_x + (conditioned_x - unconditioned_x) * cond_scale, thresh) + + +class ProgrammableCFGDenoiser(CFGDenoiser): + def forward(self, x, sigma, uncond, cond, cond_scale): + forward_lambda = lambda x, t, c: self.inner_model(x, t, cond=c) + x_new = Sampler.apply_weighted_conditioning_list(x, sigma, forward_lambda, uncond, cond, cond_scale) + + if self.warmup < self.warmup_max: + thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max)) + self.warmup += 1 + else: + thresh = self.threshold + if thresh > self.threshold: + thresh = self.threshold + return cfg_apply_threshold(x_new, threshold=thresh) class KSampler(Sampler): @@ -181,7 +203,6 @@ class KSampler(Sampler): ) # sigmas are set up in make_schedule - we take the last steps items - total_steps = len(self.sigmas) sigmas = self.sigmas[-S-1:] # x_T is variation noise. When an init image is provided (in x0) we need to add @@ -194,7 +215,7 @@ class KSampler(Sampler): else: x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] - model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10)) + model_wrap_cfg = ProgrammableCFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10)) extra_args = { 'cond': conditioning, 'uncond': unconditional_conditioning, diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index ff705513f8..42704f1175 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -4,6 +4,8 @@ ldm.models.diffusion.sampler Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc ''' +from math import ceil + import torch import numpy as np from tqdm import tqdm @@ -411,3 +413,54 @@ class Sampler(object): return self.model.inner_model.q_sample(x0,ts) ''' return self.model.q_sample(x0,ts) + + + @classmethod + def apply_weighted_conditioning_list(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale): + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) # aka sigmas + + deltas = None + uncond_latents = None + weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)] + + # below is fugly omg + num_actual_conditionings = len(c_or_weighted_c_list) + conditionings = [uc] + [c for c,weight in weighted_cond_list] + weights = [1] + [weight for c,weight in weighted_cond_list] + chunk_count = ceil(len(conditionings)/2) + assert(len(conditionings)>=2, "need at least one uncond and one cond") + deltas = None + for chunk_index in range(chunk_count): + offset = chunk_index*2 + chunk_size = min(2, len(conditionings)-offset) + + if chunk_size == 1: + c_in = conditionings[offset] + latents_a = forward_func(x_in[:-1], t_in[:-1], c_in) + latents_b = None + else: + c_in = torch.cat(conditionings[offset:offset+2]) + latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2) + + # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioining + if chunk_index == 0: + uncond_latents = latents_a + deltas = latents_b - uncond_latents + else: + deltas = torch.cat((deltas, latents_a - uncond_latents)) + if latents_b is not None: + deltas = torch.cat((deltas, latents_b - uncond_latents)) + + # merge the weighted deltas together into a single merged delta + per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device) + normalize = False + if normalize: + per_delta_weights /= torch.sum(per_delta_weights) + reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1)) + deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True) + + # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale) + # assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale)))) + + return uncond_latents + deltas_merged * global_guidance_scale diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 426fccced3..857a8a8e3e 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -1,3 +1,5 @@ +import math + import torch import torch.nn as nn from functools import partial @@ -454,6 +456,207 @@ class FrozenCLIPEmbedder(AbstractEncoder): def encode(self, text, **kwargs): return self(text, **kwargs) +class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): + + attention_weights_key = "attention_weights" + + def build_token_list_fragment(self, fragment: str, weight: float) -> (torch.Tensor, torch.Tensor): + batch_encoding = self.tokenizer( + fragment, + truncation=True, + max_length=self.max_length, + return_length=True, + return_overflowing_tokens=False, + padding='none', + return_tensors='pt', + ) + return batch_encoding, torch.ones_like(batch_encoding) * weight + + + def forward(self, text: list, **kwargs): + ''' + + :param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different + weights shall be applied. + :param kwargs: If the keyword arg "attention_weights" is passed, it shall contain a batch of lists of weights + for the prompt fragments. In this case text must contain batches of lists of prompt fragments. + :return: A tensor of shape (B, 77, 768) containing weighted embeddings + ''' + if self.attention_weights_key not in kwargs: + # fallback to base class implementation + return super().forward(text, **kwargs) + + attention_weights = kwargs[self.attention_weights_key] + # self.transformer doesn't like receiving "attention_weights" as an argument + kwargs.pop(self.attention_weights_key) + + batch_z = None + for fragments, weights in zip(text, attention_weights): + + # First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively + # applying a multiplier to the CFG scale on a per-token basis). + # For tokens weighted<1, intuitively we want SD to become not merely *less* interested in the concept + # captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active + # interest, however small, in redness; what the user probably intends when they attach the number 0.01 to + # "red" is to tell SD that it should almost completely *ignore* redness). + # To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt + # string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the + # closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment. + + # handle weights >=1 + tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights) + base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs) + + # this is our starting point + embeddings = base_embedding.unsqueeze(0) + per_embedding_weights = [1.0] + + # now handle weights <1 + # Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped + # with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting + # embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words + # removed. + # eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding + # for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it + # such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain". + for index, fragment_weight in enumerate(weights): + if fragment_weight < 1: + fragments_without_this = fragments[:index] + fragments[index+1:] + weights_without_this = weights[:index] + weights[index+1:] + tokens, per_token_weights = self.get_tokens_and_weights(fragments_without_this, weights_without_this) + embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs) + + embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1) + # weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0 + # if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding + # therefore: + # fragment_weight = 1: we are at base_z => lerp weight 0 + # fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1 + # fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf + # so let's use tan(), because: + # tan is 0.0 at 0, + # 1.0 at PI/4, and + # inf at PI/2 + # -> tan((1-weight)*PI/2) should give us ideal lerp weights + epsilon = 1e-9 + fragment_weight = max(epsilon, fragment_weight) # inf is bad + embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2) + # todo handle negative weight? + + per_embedding_weights.append(embedding_lerp_weight) + + lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0) + + print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}") + + # append to batch + batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat((batch_z, lerped_embeddings.unsqueeze(0)), dim=1) + + # should have shape (B, 77, 768) + print(f"assembled all tokens into tensor of shape {batch_z.shape}") + + return batch_z + + @classmethod + def apply_embedding_weights(self, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor: + per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device) + if normalize: + per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights) + reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,)) + #reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape) + return torch.sum(embeddings * reshaped_weights, dim=1) + # lerped embeddings has shape (77, 768) + + + def get_tokens_and_weights(self, fragments: list[str], weights: list[float]) -> (torch.Tensor, torch.Tensor): + ''' + + :param fragments: + :param weights: Per-fragment weights (CFG scaling). No need for these to be normalized. They will not be normalized here and that's fine. + :return: + ''' + # empty is meaningful + if len(fragments) == 0 and len(weights) == 0: + fragments = [''] + weights = [1] + item_encodings = self.tokenizer( + fragments, + truncation=True, + max_length=self.max_length, + return_overflowing_tokens=False, + padding='do_not_pad', + return_tensors=None, # just give me a list of ints + )['input_ids'] + all_tokens = [] + per_token_weights = [] + print("all fragments:", fragments, weights) + for index, fragment in enumerate(item_encodings): + weight = weights[index] + print("processing fragment", fragment, weight) + fragment_tokens = item_encodings[index] + print("fragment", fragment, "processed to", fragment_tokens) + # trim bos and eos markers before appending + all_tokens.extend(fragment_tokens[1:-1]) + per_token_weights.extend([weight] * (len(fragment_tokens) - 2)) + + if len(all_tokens) > self.max_length - 2: + print("prompt is too long and has been truncated") + all_tokens = all_tokens[:self.max_length - 2] + + # pad out to a 77-entry array: [eos_token, , eos_token, ..., eos_token] + # (77 = self.max_length) + pad_length = self.max_length - 1 - len(all_tokens) + all_tokens.insert(0, self.tokenizer.bos_token_id) + all_tokens.extend([self.tokenizer.eos_token_id] * pad_length) + per_token_weights.insert(0, 1) + per_token_weights.extend([1] * pad_length) + + all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device) + per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device) + print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}") + return all_tokens_tensor, per_token_weights_tensor + + def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor: + ''' + Build a tensor representing the passed-in tokens, each of which has a weight. + :param tokens: A tensor of shape (77) containing token ids (integers) + :param per_token_weights: A tensor of shape (77) containing weights (floats) + :param method: Whether to multiply the whole feature vector for each token or just its distance from an "empty" feature vector + :param kwargs: passed on to self.transformer() + :return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings. + ''' + #print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}") + z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs) + batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape) + + if weight_delta_from_empty: + empty_tokens = self.tokenizer([''] * z.shape[0], + truncation=True, + max_length=self.max_length, + padding='max_length', + return_tensors='pt' + )['input_ids'].to(self.device) + empty_z = self.transformer(input_ids=empty_tokens, **kwargs) + z_delta_from_empty = z - empty_z + weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded) + + weighted_z_delta_from_empty = (weighted_z-empty_z) + print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() ) + + #print("using empty-delta method, first 5 rows:") + #print(weighted_z[:5]) + + return weighted_z + + else: + original_mean = z.mean() + z *= batch_weights_expanded + after_weighting_mean = z.mean() + # correct the mean. not sure if this is right but it's what the automatic1111 fork of SD does + mean_correction_factor = original_mean/after_weighting_mean + z *= mean_correction_factor + return z + class FrozenCLIPTextEmbedder(nn.Module): """ diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py new file mode 100644 index 0000000000..207475d02e --- /dev/null +++ b/tests/test_prompt_parser.py @@ -0,0 +1,136 @@ +import unittest + +from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt + +def parse_prompt(prompt_string): + pp = PromptParser() + #print(f"parsing '{prompt_string}'") + parse_result = pp.parse(prompt_string) + #print(f"-> parsed '{prompt_string}' to {parse_result}") + return parse_result + +class PromptParserTestCase(unittest.TestCase): + + def test_empty(self): + self.assertEqual(Conjunction([FlattenedPrompt([('', 1)])]), parse_prompt('')) + + def test_basic(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire (flames)', 1)])]), parse_prompt("fire (flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([("fire flames", 1)])]), parse_prompt("fire flames")) + self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames", 1)])]), parse_prompt("fire, flames")) + self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames , fire", 1)])]), parse_prompt("fire, flames , fire")) + + def test_attention(self): + self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.5)])]), parse_prompt("0.5(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('fire flames', 0.5)])]), parse_prompt("0.5(fire flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flames', 1.1)])]), parse_prompt("+(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.9)])]), parse_prompt("-(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1), ('flames', 0.5)])]), parse_prompt("fire 0.5(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(1.1, 2))])]), parse_prompt("++(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(0.9, 2))])]), parse_prompt("--(flames)")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))])]), parse_prompt("---(flowers) +++flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))])]), parse_prompt("---(flowers) +++flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames+', pow(1.1, 3))])]), + parse_prompt("---(flowers) +++flames+")) + self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1)])]), + parse_prompt("+(pretty flowers)")) + self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1), (', the flames are too hot', 1)])]), + parse_prompt("+(pretty flowers), the flames are too hot")) + + def test_no_parens_attention_runon(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("++fire flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("--fire flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("flowers ++fire flames")) + self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("flowers --fire flames")) + + + def test_explicit_conjunction(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and(1,1)')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()')) + self.assertEqual( + Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("2.0(fire)", "-flames").and()')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]), + FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()')) + + def test_conjunction_weights(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[2.0,1.0]), parse_prompt('("fire", "flames").and(2,1)')) + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[1.0,2.0]), parse_prompt('("fire", "flames").and(1,2)')) + + with self.assertRaises(PromptParser.ParsingException): + parse_prompt('("fire", "flames").and(2)') + parse_prompt('("fire", "flames").and(2,1,2)') + + def test_complex_conjunction(self): + self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]), + parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle)\").and(0.5, 0.5)")) + + def test_badly_formed(self): + def make_untouched_prompt(prompt): + return Conjunction([FlattenedPrompt([(prompt, 1.0)])]) + + def assert_if_prompt_string_not_untouched(prompt): + self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt)) + + assert_if_prompt_string_not_untouched('a test prompt') + assert_if_prompt_string_not_untouched('a badly (formed test prompt') + assert_if_prompt_string_not_untouched('a badly formed test+ prompt') + assert_if_prompt_string_not_untouched('a badly (formed test+ prompt') + assert_if_prompt_string_not_untouched('a badly (formed test+ )prompt') + assert_if_prompt_string_not_untouched('a badly (formed test+ )prompt') + assert_if_prompt_string_not_untouched('(((a badly (formed test+ )prompt') + assert_if_prompt_string_not_untouched('(a (ba)dly (f)ormed test+ prompt') + self.assertEqual(Conjunction([FlattenedPrompt([('(a (ba)dly (f)ormed test+', 1.0), ('prompt', 1.1)])]), + parse_prompt('(a (ba)dly (f)ormed test+ +prompt')) + self.assertEqual(Conjunction([Blend([FlattenedPrompt([('((a badly (formed test+', 1.0)])], weights=[1.0])]), + parse_prompt('("((a badly (formed test+ ").blend(1.0)')) + + def test_blend(self): + self.assertEqual(Conjunction( + [Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]), + parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)") + ) + self.assertEqual(Conjunction([Blend( + [FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('hi', 1.0)])], + [0.7, 0.3, 1.0])]), + parse_prompt("(\"fire\", \"fire flames\", \"hi\").blend(0.7, 0.3, 1.0)") + ) + self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), + FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]), + FlattenedPrompt([('hi', 1.0)])], + weights=[0.7, 0.3, 1.0])]), + parse_prompt("(\"fire\", \"fire flames ++(hot)\", \"hi\").blend(0.7, 0.3, 1.0)") + ) + # blend a single entry is not a failure + self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]), + parse_prompt("(\"fire\").blend(0.7)") + ) + # blend with empty + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \"\").blend(0.7, 1)") + ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \" \").blend(0.7, 1)") + ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \" \").blend(0.7, 1)") + ) + self.assertEqual( + Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([(',', 1.0)])], [0.7, 1.0])]), + parse_prompt("(\"fire\", \" , \").blend(0.7, 1)") + ) + + + def test_nested(self): + self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)])]), + parse_prompt('fire 2.0(flames 1.5(trees))')) + self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]), + FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])], + weights=[1.0, 1.0])]), + parse_prompt('("fire ++(flames)", "mountain 2(man)").blend(1,1)')) + +if __name__ == '__main__': + unittest.main() From 11d7e6b92f2deab1ad4998aad835920593a1e6d3 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sat, 15 Oct 2022 23:58:13 +0200 Subject: [PATCH 2/4] undo unwanted changes --- ldm/generate.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ldm/generate.py b/ldm/generate.py index 965d37a240..7fb68dec0a 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -438,7 +438,7 @@ class Generate: sampler=self.sampler, steps=steps, cfg_scale=cfg_scale, - conditioning=(uc, c), # here change to arrays + conditioning=(uc, c), ddim_eta=ddim_eta, image_callback=image_callback, # called after the final image is generated step_callback=step_callback, # called after each intermediate image is generated @@ -477,10 +477,6 @@ class Generate: print('**Interrupted** Partial results will be returned.') else: raise KeyboardInterrupt - # brute-force fallback - except Exception as e: - print(traceback.format_exc(), file=sys.stderr) - print('>> Could not generate image.') toc = time.time() print('>> Usage stats:') From c6ae9f117634bbfa5d385e98319a36a7701f6ecb Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sun, 16 Oct 2022 00:45:38 +0200 Subject: [PATCH 3/4] remove unnecessary assertion --- ldm/models/diffusion/sampler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py index 42704f1175..417d1d4491 100644 --- a/ldm/models/diffusion/sampler.py +++ b/ldm/models/diffusion/sampler.py @@ -429,7 +429,6 @@ class Sampler(object): conditionings = [uc] + [c for c,weight in weighted_cond_list] weights = [1] + [weight for c,weight in weighted_cond_list] chunk_count = ceil(len(conditionings)/2) - assert(len(conditionings)>=2, "need at least one uncond and one cond") deltas = None for chunk_index in range(chunk_count): offset = chunk_index*2 From 61357e4e6eecc43b9d52859eb63b4db4e5ffafc1 Mon Sep 17 00:00:00 2001 From: Damian at mba Date: Sun, 16 Oct 2022 01:53:44 +0200 Subject: [PATCH 4/4] be less verbose when assembling prompt --- ldm/invoke/conditioning.py | 13 +++++++------ ldm/modules/encoders/modules.py | 19 ++++++++++--------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py index 9b67d5040d..e3190f6ed6 100644 --- a/ldm/invoke/conditioning.py +++ b/ldm/invoke/conditioning.py @@ -35,9 +35,10 @@ def get_uc_and_c(prompt_string_uncleaned, model, log_tokens=False, skip_normaliz pp = PromptParser() - def build_conditioning_list(prompt_string:str): + def build_conditioning_list(prompt_string:str, verbose:bool = False): parsed_conjunction: Conjunction = pp.parse(prompt_string) - print(f"parsed '{prompt_string}' to {parsed_conjunction}") + if verbose: + print(f"parsed '{prompt_string}' to {parsed_conjunction}") assert (type(parsed_conjunction) is Conjunction) conditioning_list = [] @@ -46,7 +47,7 @@ def get_uc_and_c(prompt_string_uncleaned, model, log_tokens=False, skip_normaliz raise f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead" fragments = [x[0] for x in flattened_prompt.children] attention_weights = [x[1] for x in flattened_prompt.children] - print(fragments, attention_weights) + #print(fragments, attention_weights) return model.get_learned_conditioning([fragments], attention_weights=[attention_weights]) for part,weight in zip(parsed_conjunction.prompts, parsed_conjunction.weights): @@ -65,14 +66,14 @@ def get_uc_and_c(prompt_string_uncleaned, model, log_tokens=False, skip_normaliz return conditioning_list - positive_conditioning_list = build_conditioning_list(prompt_string_cleaned) - negative_conditioning_list = build_conditioning_list(unconditioned_words) + positive_conditioning_list = build_conditioning_list(prompt_string_cleaned, verbose=True) + negative_conditioning_list = build_conditioning_list(unconditioned_words, verbose=(len(unconditioned_words)>0) ) if len(negative_conditioning_list) == 0: negative_conditioning = model.get_learned_conditioning([['']], attention_weights=[[1]]) else: if len(negative_conditioning_list)>1: - print("cannot do conjunctions on unconditioning for now") + print("cannot do conjunctions on unconditioning for now, everything except the first prompt will be ignored") negative_conditioning = negative_conditioning_list[0][0] #positive_conditioning_list.append((get_blend_prompts_and_weights(prompt), this_weight)) diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py index 857a8a8e3e..fcd0363e80 100644 --- a/ldm/modules/encoders/modules.py +++ b/ldm/modules/encoders/modules.py @@ -547,13 +547,13 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0) - print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}") + #print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}") # append to batch batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat((batch_z, lerped_embeddings.unsqueeze(0)), dim=1) # should have shape (B, 77, 768) - print(f"assembled all tokens into tensor of shape {batch_z.shape}") + #print(f"assembled all tokens into tensor of shape {batch_z.shape}") return batch_z @@ -589,18 +589,19 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): )['input_ids'] all_tokens = [] per_token_weights = [] - print("all fragments:", fragments, weights) + #print("all fragments:", fragments, weights) for index, fragment in enumerate(item_encodings): weight = weights[index] - print("processing fragment", fragment, weight) + #print("processing fragment", fragment, weight) fragment_tokens = item_encodings[index] - print("fragment", fragment, "processed to", fragment_tokens) + #print("fragment", fragment, "processed to", fragment_tokens) # trim bos and eos markers before appending all_tokens.extend(fragment_tokens[1:-1]) per_token_weights.extend([weight] * (len(fragment_tokens) - 2)) - if len(all_tokens) > self.max_length - 2: - print("prompt is too long and has been truncated") + if (len(all_tokens) + 2) > self.max_length: + excess_token_count = (len(all_tokens) + 2) - self.max_length + print(f"prompt is {excess_token_count} token(s) too long and has been truncated") all_tokens = all_tokens[:self.max_length - 2] # pad out to a 77-entry array: [eos_token, , eos_token, ..., eos_token] @@ -613,7 +614,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device) per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device) - print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}") + #print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}") return all_tokens_tensor, per_token_weights_tensor def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor: @@ -641,7 +642,7 @@ class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder): weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded) weighted_z_delta_from_empty = (weighted_z-empty_z) - print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() ) + #print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() ) #print("using empty-delta method, first 5 rows:") #print(weighted_z[:5])