diff --git a/backend/server.py b/backend/server.py
index 7b8a8a5a69..f14c141e12 100644
--- a/backend/server.py
+++ b/backend/server.py
@@ -527,7 +527,7 @@ def parameters_to_generated_image_metadata(parameters):
         rfc_dict["sampler"] = parameters["sampler_name"]
 
     # display weighted subprompts (liable to change)
-    subprompts = split_weighted_subprompts(parameters["prompt"])
+    subprompts = split_weighted_subprompts(parameters["prompt"], skip_normalize=True)
     subprompts = [{"prompt": x[0], "weight": x[1]} for x in subprompts]
     rfc_dict["prompt"] = subprompts
diff --git a/configs/stable-diffusion/v1-inference.yaml b/configs/stable-diffusion/v1-inference.yaml
index 9c773077b6..baf91f6e26 100644
--- a/configs/stable-diffusion/v1-inference.yaml
+++ b/configs/stable-diffusion/v1-inference.yaml
@@ -76,4 +76,4 @@ model:
       target: torch.nn.Identity
 
     cond_stage_config:
-      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
+      target: ldm.modules.encoders.modules.WeightedFrozenCLIPEmbedder
diff --git a/ldm/generate.py b/ldm/generate.py
index 7fb68dec0a..965d37a240 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -438,7 +438,7 @@ class Generate:
                 sampler=self.sampler,
                 steps=steps,
                 cfg_scale=cfg_scale,
-                conditioning=(uc, c),
+                conditioning=(uc, c),  # c is now a list of (embedding, weight) pairs rather than a single tensor
                 ddim_eta=ddim_eta,
                 image_callback=image_callback,  # called after the final image is generated
                 step_callback=step_callback,    # called after each intermediate image is generated
@@ -477,6 +477,10 @@ class Generate:
                 print('**Interrupted** Partial results will be returned.')
             else:
                 raise KeyboardInterrupt
+        # brute-force fallback: log the traceback and carry on, so a failed generation
+        # still reports usage stats instead of killing the session
+        except Exception as e:
+            print(traceback.format_exc(), file=sys.stderr)
+            print('>> Could not generate image.')
 
         toc = time.time()
         print('>> Usage stats:')
diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py
index fedd965a2c..9b67d5040d 100644
--- a/ldm/invoke/conditioning.py
+++ b/ldm/invoke/conditioning.py
@@ -12,42 +12,76 @@ log_tokenization()    print out colour-coded tokens and warn if truncated
 import re
 import torch
 
-def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
+from .prompt_parser import PromptParser, Fragment, Attention, Blend, Conjunction, FlattenedPrompt
+from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
+
+
+def get_uc_and_c(prompt_string_uncleaned, model, log_tokens=False, skip_normalize=False):
+    # Extract Unconditioned Words From Prompt
     unconditioned_words = ''
     unconditional_regex = r'\[(.*?)\]'
-    unconditionals = re.findall(unconditional_regex, prompt)
+    unconditionals = re.findall(unconditional_regex, prompt_string_uncleaned)
 
     if len(unconditionals) > 0:
         unconditioned_words = ' '.join(unconditionals)
 
         # Remove Unconditioned Words From Prompt
         unconditional_regex_compile = re.compile(unconditional_regex)
-        clean_prompt = unconditional_regex_compile.sub(' ', prompt)
-        prompt = re.sub(' +', ' ', clean_prompt)
+        clean_prompt = unconditional_regex_compile.sub(' ', prompt_string_uncleaned)
+        prompt_string_cleaned = re.sub(' +', ' ', clean_prompt)
+    else:
+        prompt_string_cleaned = prompt_string_uncleaned
 
-    uc = model.get_learned_conditioning([unconditioned_words])
+    pp = PromptParser()
 
-    # get weighted sub-prompts
-    weighted_subprompts = split_weighted_subprompts(
-        prompt, skip_normalize
-    )
+    def build_conditioning_list(prompt_string:str):
+        parsed_conjunction: Conjunction = pp.parse(prompt_string)
+        print(f"parsed '{prompt_string}' to {parsed_conjunction}")
+        assert (type(parsed_conjunction) is Conjunction)
+
+        conditioning_list = []
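+        # each entry of conditioning_list is an (embedding tensor, weight) pair, one per
+        # .and() term; e.g. '("a", "b").and(1, 2)' yields two entries with weights 1.0 and
+        # 2.0, while a plain prompt yields a single entry with weight 1.0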
+        def make_embeddings_for_flattened_prompt(flattened_prompt: FlattenedPrompt):
+            if type(flattened_prompt) is not FlattenedPrompt:
+                raise PromptParser.ParsingException(
+                    f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
+            fragments = [x[0] for x in flattened_prompt.children]
+            attention_weights = [x[1] for x in flattened_prompt.children]
+            print(fragments, attention_weights)
+            return model.get_learned_conditioning([fragments], attention_weights=[attention_weights])
+
+        for part,weight in zip(parsed_conjunction.prompts, parsed_conjunction.weights):
+            if type(part) is Blend:
+                blend:Blend = part
+                embeddings_to_blend = None
+                for flattened_prompt in blend.prompts:
+                    this_embedding = make_embeddings_for_flattened_prompt(flattened_prompt)
+                    embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat((embeddings_to_blend, this_embedding))
+                blended_embeddings = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0), blend.weights, normalize=blend.normalize_weights)
+                conditioning_list.append((blended_embeddings, weight))
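+                # e.g. '("fire", "ice").blend(0.7, 0.3)' stacks the two (1, 77, 768)
+                # embeddings and apply_embedding_weights sums them with the (by default
+                # normalized) blend weights, producing one conditioning tensor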
+            else:
+                flattened_prompt: FlattenedPrompt = part
+                embeddings = make_embeddings_for_flattened_prompt(flattened_prompt)
+                conditioning_list.append((embeddings, weight))
+
+        return conditioning_list
+
+    positive_conditioning_list = build_conditioning_list(prompt_string_cleaned)
+    negative_conditioning_list = build_conditioning_list(unconditioned_words)
+
+    if len(negative_conditioning_list) == 0:
+        negative_conditioning = model.get_learned_conditioning([['']], attention_weights=[[1]])
+    else:
+        if len(negative_conditioning_list)>1:
+            print("cannot do conjunctions on unconditioning for now")
+        negative_conditioning = negative_conditioning_list[0][0]
+
+    # "unconditioned" means "the conditioning tensor is empty"
+    uc = negative_conditioning
+    c = positive_conditioning_list
 
-    if len(weighted_subprompts) > 1:
-        # i dont know if this is correct.. but it works
-        c = torch.zeros_like(uc)
-        # normalize each "sub prompt" and add it
-        for subprompt, weight in weighted_subprompts:
-            log_tokenization(subprompt, model, log_tokens, weight)
-            c = torch.add(
-                c,
-                model.get_learned_conditioning([subprompt]),
-                alpha=weight,
-            )
-    else:   # just standard 1 prompt
-        log_tokenization(prompt, model, log_tokens, 1)
-        c = model.get_learned_conditioning([prompt])
-        uc = model.get_learned_conditioning([unconditioned_words])
 
     return (uc, c)
 
 def split_weighted_subprompts(text, skip_normalize=False)->list:
diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py
new file mode 100644
index 0000000000..c976918291
--- /dev/null
+++ b/ldm/invoke/prompt_parser.py
@@ -0,0 +1,331 @@
+import pyparsing
+import pyparsing as pp
+from pyparsing import original_text_for
+
+
+class Prompt():
+
+    def __init__(self, parts: list):
+        for c in parts:
+            allowed_types = [Fragment, Attention, CFGScale]
+            if type(c) not in allowed_types:
+                raise PromptParser.ParsingException(f"Prompt cannot contain {type(c)}, only {allowed_types} are allowed")
+        self.children = parts
+    def __repr__(self):
+        return f"Prompt:{self.children}"
+    def __eq__(self, other):
+        return type(other) is Prompt and other.children == self.children
+
+class FlattenedPrompt():
+    def __init__(self, parts: list):
+        # verify type correctness
+        for c in parts:
+            if type(c) is not tuple:
+                raise PromptParser.ParsingException(
+                    f"FlattenedPrompt cannot contain {type(c)}, only ('text', weight) tuples are allowed")
+            text = c[0]
+            weight = c[1]
+            if type(text) is not str:
+                raise PromptParser.ParsingException(f"FlattenedPrompt cannot contain {type(c)}, only ('text', weight) tuples are allowed")
+            if type(weight) is not float and type(weight) is not int:
+                raise PromptParser.ParsingException(
+                    f"FlattenedPrompt cannot contain {type(c)}, only ('text', weight) tuples are allowed")
+        # all looks good
+        self.children = parts
+
+    def __repr__(self):
+        return f"FlattenedPrompt:{self.children}"
+    def __eq__(self, other):
+        return type(other) is FlattenedPrompt and other.children == self.children
+
+
+class Attention():
+
+    def __init__(self, weight: float, children: list):
+        self.weight = weight
+        self.children = children
+        #print(f"A: requested attention '{children}' to {weight}")
+
+    def __repr__(self):
+        return f"Attention:'{self.children}' @ {self.weight}"
+    def __eq__(self, other):
+        return type(other) is Attention and other.weight == self.weight and other.children == self.children
+
+
+class CFGScale():
+    def __init__(self, scale_factor: float, fragment: str):
+        self.fragment = fragment
+        self.scale_factor = scale_factor
+        #print(f"S: requested CFGScale '{fragment}' x {scale_factor}")
+
+    def __repr__(self):
+        return f"CFGScale:'{self.fragment}' x {self.scale_factor}"
+    def __eq__(self, other):
+        return type(other) is CFGScale and other.scale_factor == self.scale_factor and other.fragment == self.fragment
+
+
+
+class Fragment():
+    def __init__(self, text: str):
+        assert(type(text) is str)
+        self.text = text
+
+    def __repr__(self):
+        return "Fragment:'"+self.text+"'"
+    def __eq__(self, other):
+        return type(other) is Fragment and other.text == self.text
+
+class Conjunction():
+    def __init__(self, prompts: list, weights: list = None):
+        # force everything to be a Prompt
+        #print("making conjunction with", prompts)
+        self.prompts = [x if (type(x) is Prompt or type(x) is Blend or type(x) is FlattenedPrompt)
+                        else Prompt(x) for x in prompts]
+        self.weights = [1.0]*len(self.prompts) if weights is None else list(weights)
+        if len(self.weights) != len(self.prompts):
+            raise PromptParser.ParsingException(f"while parsing Conjunction: mismatched parts/weights counts {prompts}, {weights}")
+        self.type = 'AND'
+
+    def __repr__(self):
+        return f"Conjunction:{self.prompts} | weights {self.weights}"
+    def __eq__(self, other):
+        return type(other) is Conjunction \
+               and other.prompts == self.prompts \
+               and other.weights == self.weights
+
+
+class Blend():
+    def __init__(self, prompts: list, weights: list[float], normalize_weights: bool=True):
+        #print("making Blend with prompts", prompts, "and weights", weights)
+        if len(prompts) != len(weights):
+            raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
+        for c in prompts:
+            if type(c) is not Prompt and type(c) is not FlattenedPrompt:
+                raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
+        # upcast all lists to Prompt objects
+        self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
+                        else Prompt(x) for x in prompts]
+        self.weights = weights
+        self.normalize_weights = normalize_weights
+
+    def __repr__(self):
+        return f"Blend:{self.prompts} | weights {self.weights}"
+    def __eq__(self, other):
+        return other.__repr__() == self.__repr__()
+
+
+class PromptParser():
+
+    class ParsingException(Exception):
+        pass
+
+    def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
+
+        self.attention_plus_base = attention_plus_base
+        self.attention_minus_base = attention_minus_base
+
+        self.root = self.build_parser_logic()
+
+
+    def parse(self, prompt: str) -> Conjunction:
+        '''
+        :param prompt: The prompt string to parse
+        :return: a Conjunction of FlattenedPrompts and/or Blends
+        '''
+        #print(f"!!parsing '{prompt}'")
+
+        if len(prompt.strip()) == 0:
+            return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0])
+
+        root = self.root.parse_string(prompt)
+        #print(f"'{prompt}' parsed to root", root)
+        #fused = fuse_fragments(parts)
+        #print("fused to", fused)
+
+        return self.flatten(root[0])
+
+    def flatten(self, root: Conjunction):
+
+        def fuse_fragments(items):
+            # print("fusing fragments in ", items)
+            result = []
+            for x in items:
+                last_weight = result[-1][1] if len(result) > 0 else None
+                this_text = x[0]
+                this_weight = x[1]
+                if last_weight is not None and last_weight == this_weight:
+                    last_text = result[-1][0]
+                    result[-1] = (last_text + ' ' + this_text, last_weight)
+                else:
+                    result.append(x)
+            return result
+
+        def flatten_internal(node, weight_scale, results, prefix):
+            #print(prefix + "flattening", node, "...")
+            if type(node) is pp.ParseResults:
+                for x in node:
+                    results = flatten_internal(x, weight_scale, results, prefix+'pr')
+                #print(prefix, " ParseResults expanded, results is now", results)
+            elif type(node) is Fragment:
+                results.append((node.text, float(weight_scale)))
+            elif type(node) is Attention:
+                #if node.weight < 1:
+                #    todo: inject a blend when flattening attention with weight <1
+                for c in node.children:
+                    results = flatten_internal(c, weight_scale*node.weight, results, prefix+' ')
+            elif type(node) is Blend:
+                flattened_subprompts = []
+                #print(" flattening blend with prompts", node.prompts, "weights", node.weights)
+                for prompt in node.prompts:
+                    # prompt is a list
+                    flattened_subprompts = flatten_internal(prompt, weight_scale, flattened_subprompts, prefix+'B ')
+                results += [Blend(prompts=flattened_subprompts, weights=node.weights, normalize_weights=node.normalize_weights)]
+            elif type(node) is Prompt:
+                #print(prefix + "about to flatten Prompt with children", node.children)
+                flattened_prompt = []
+                for child in node.children:
+                    flattened_prompt = flatten_internal(child, weight_scale, flattened_prompt, prefix+'P ')
+                results += [FlattenedPrompt(parts=fuse_fragments(flattened_prompt))]
+                #print(prefix + "after flattening Prompt, results is", results)
+            else:
+                raise PromptParser.ParsingException(f"unhandled node type {type(node)} when flattening {node}")
+            #print(prefix + "-> after flattening", type(node), "results is", results)
+            return results
+
+        #print("flattening", root)
+
+        flattened_parts = []
+        for part in root.prompts:
+            flattened_parts += flatten_internal(part, 1.0, [], ' C| ')
+        weights = root.weights
+        return Conjunction(flattened_parts, weights)
+
+
+
+    def build_parser_logic(self):
+
+        lparen = pp.Literal("(").suppress()
+        rparen = pp.Literal(")").suppress()
+        # accepts int or float notation, always maps to float
+        number = pyparsing.pyparsing_common.real | pp.Word(pp.nums).set_parse_action(pp.token_map(float))
+        SPACE_CHARS = ' \t\n'
+
+        prompt_part = pp.Forward()
+        word = pp.Forward()
+
+        def make_fragment(x):
+            #print("### making fragment for", x)
+            if type(x) is str:
+                return Fragment(x)
+            elif type(x) is pp.ParseResults or type(x) is list:
+                return Fragment(' '.join([s for s in x]))
+            else:
+                raise PromptParser.ParsingException("Cannot make fragment from " + str(x))
+
+        # attention control of the form +(phrase) / -(phrase) / <weight>(phrase)
+        # the phrase can be multiple words; repeat the +/- signs to strengthen the effect,
+        # or give an explicit floating point or integer weight
+        attention = pp.Forward()
+        attention_head = (number | pp.Word('+') | pp.Word('-'))\
+            .set_name("attention_head")\
+            .set_debug(False)
+        fragment_inside_attention = pp.CharsNotIn(SPACE_CHARS+'()')\
+            .set_parse_action(make_fragment)\
+            .set_name("fragment_inside_attention")\
+            .set_debug(False)
+        attention_with_parens = pp.Forward()
+        attention_with_parens_body = pp.nested_expr(content=pp.delimited_list((attention_with_parens | fragment_inside_attention), delim=SPACE_CHARS))
+        attention_with_parens << (attention_head + attention_with_parens_body)
+
+        def make_attention(x):
+            # print("making Attention from parsing with args", x[0], x[1])
+            weight = 1
+            # number(str)
+            if type(x[0]) is float or type(x[0]) is int:
+                weight = float(x[0])
+            # +(str) or -(str) or +str or -str
+            elif type(x[0]) is str:
+                base = self.attention_plus_base if x[0][0] == '+' else self.attention_minus_base
+                weight = pow(base, len(x[0]))
+            # print("Making attention with children of type", [str(type(c)) for c in x[1]])
+            return Attention(weight=weight, children=x[1])
+
+        attention_with_parens.set_parse_action(make_attention)\
+            .set_name("attention_with_parens")\
+            .set_debug(False)
+
+        # attention control of the form ++word --word (no parens)
+        attention_without_parens = (
+                (pp.Word('+') | pp.Word('-')) +
+                pp.CharsNotIn(SPACE_CHARS+'()').set_parse_action(lambda x: [[make_fragment(x)]])
+            )\
+            .set_name("attention_without_parens")\
+            .set_debug(False)
+        attention_without_parens.set_parse_action(make_attention)
+
+        attention << (attention_with_parens | attention_without_parens)\
+            .set_name("attention")\
+            .set_debug(False)
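+        # examples, with the default attention bases of 1.1 / 0.9:
+        #   '+(pretty flowers)'  -> Attention(weight=1.1)
+        #   '--fire'             -> Attention(weight=0.9**2 == 0.81)
+        #   '0.5(fire flames)'   -> Attention(weight=0.5)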
+
+        # fragments of text with no attention control
+        word << pp.Word(pp.printables).set_parse_action(lambda x: Fragment(' '.join([s for s in x])))
+        word.set_name("word")
+        word.set_debug(False)
+        prompt_part << (attention | word)
+        prompt_part.set_debug(False)
+        prompt_part.set_name("prompt_part")
+
+        # root prompt definition
+        prompt = pp.Group(pp.OneOrMore(prompt_part))\
+            .set_parse_action(lambda x: Prompt(x[0]))
+
+        # weighted blend of prompts
+        # ("promptA", "promptB").blend(a, b) where "promptA" and "promptB" are valid prompts and a and b are float or
+        # int weights.
+        # can specify more terms eg ("promptA", "promptB", "promptC").blend(a,b,c)
+
+        def make_prompt_from_quoted_string(x):
+            #print(' got quoted prompt', x)
+
+            x_unquoted = x[0][1:-1]
+            if len(x_unquoted.strip()) == 0:
+                # print(' b : just an empty string')
+                return Prompt([Fragment('')])
+            # print(' b parsing ', x_unquoted)
+            x_parsed = prompt.parse_string(x_unquoted)
+            #print("  quoted prompt was parsed to", type(x_parsed),":", x_parsed)
+            return x_parsed[0]
+
+        quoted_prompt = pp.dbl_quoted_string.set_parse_action(make_prompt_from_quoted_string)
+        quoted_prompt.set_name('quoted_prompt')
+
+        blend_terms = pp.delimited_list(quoted_prompt).set_name('blend_terms')
+        blend_weights = pp.delimited_list(number).set_name('blend_weights')
+        blend = pp.Group(lparen + pp.Group(blend_terms) + rparen
+                         + pp.Literal(".blend").suppress()
+                         + lparen + pp.Group(blend_weights) + rparen).set_name('blend')
+        blend.set_debug(False)
+
+
+        blend.set_parse_action(lambda x: Blend(prompts=x[0][0], weights=x[0][1]))
+
+        conjunction_terms = blend_terms.copy().set_name('conjunction_terms')
+        conjunction_weights = blend_weights.copy().set_name('conjunction_weights')
+        conjunction_with_parens_and_quotes = pp.Group(lparen + pp.Group(conjunction_terms) + rparen
+                                                      + pp.Literal(".and").suppress()
+                                                      + lparen + pp.Optional(pp.Group(conjunction_weights)) + rparen).set_name('conjunction')
+        def make_conjunction(x):
+            parts_raw = x[0][0]
+            weights = x[0][1] if len(x[0])>1 else [1.0]*len(parts_raw)
+            parts = [part for part in parts_raw]
+            return Conjunction(parts, weights)
+        conjunction_with_parens_and_quotes.set_parse_action(make_conjunction)
+
+        implicit_conjunction = pp.OneOrMore(blend | prompt)
+        implicit_conjunction.set_parse_action(lambda x: Conjunction(x))
+
+        conjunction = conjunction_with_parens_and_quotes | implicit_conjunction
+        conjunction.set_debug(False)
+
+        # top-level is a conjunction of one or more blends or prompts
+        return conjunction
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index f5dada8627..5120d92c48 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -34,23 +34,21 @@ class DDIMSampler(Sampler):
         b, *_, device = *x.shape, x.device
 
         if (
-            unconditional_conditioning is None
-            or unconditional_guidance_scale == 1.0
+            (unconditional_conditioning is None
+            or unconditional_guidance_scale == 1.0)
+            and not isinstance(c, list)
         ):
             e_t = self.model.apply_model(x, t, c)
         else:
-            x_in = torch.cat([x] * 2)
-            t_in = torch.cat([t] * 2)
-            c_in = torch.cat([unconditional_conditioning, c])
-            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
-            e_t = e_t_uncond + unconditional_guidance_scale * (
-                e_t - e_t_uncond
-            )
+            e_t = self.apply_weighted_conditioning_list(x, t, self.model.apply_model, unconditional_conditioning, c, unconditional_guidance_scale)
 
         if score_corrector is not None:
             assert self.model.parameterization == 'eps'
+            if isinstance(c, list) and len(c)>1:
+                print("warning: ddim score modifier currently ignores all but the first part of the prompt conjunction, this is probably wrong")
+            corrector_c = c[0][0] if isinstance(c, list) else c
             e_t = score_corrector.modify_score(
-                self.model, e_t, x, t, c, **corrector_kwargs
+                self.model, e_t, x, t, corrector_c, **corrector_kwargs
             )
 
         alphas = (
diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
index 4b62b5e393..0f55786323 100644
--- a/ldm/models/diffusion/ddpm.py
+++ b/ldm/models/diffusion/ddpm.py
@@ -820,13 +820,13 @@ class LatentDiffusion(DDPM):
         )
         return self.scale_factor * z
 
-    def get_learned_conditioning(self, c):
+    def get_learned_conditioning(self, c, attention_weights=None):
         if self.cond_stage_forward is None:
             if hasattr(self.cond_stage_model, 'encode') and callable(
                 self.cond_stage_model.encode
             ):
                 c = self.cond_stage_model.encode(
-                    c, embedding_manager=self.embedding_manager
+                    c, embedding_manager=self.embedding_manager, attention_weights=attention_weights
                 )
                 if isinstance(c, DiagonalGaussianDistribution):
                     c = c.mode()
diff --git a/ldm/models/diffusion/ksampler.py b/ldm/models/diffusion/ksampler.py
index ac0615b30c..4d37c8cf9b 100644
--- a/ldm/models/diffusion/ksampler.py
+++ b/ldm/models/diffusion/ksampler.py
@@ -38,7 +38,8 @@ class CFGDenoiser(nn.Module):
         x_in = torch.cat([x] * 2)
         sigma_in = torch.cat([sigma] * 2)
         cond_in = torch.cat([uncond, cond])
-        uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+        unconditioned_x, conditioned_x = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
+
         if self.warmup < self.warmup_max:
             thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
             self.warmup += 1
@@ -46,7 +47,28 @@ class CFGDenoiser(nn.Module):
             thresh = self.threshold
         if thresh > self.threshold:
             thresh = self.threshold
-        return cfg_apply_threshold(uncond + (cond - uncond) * cond_scale, thresh)
+
+        # damian0815 thinking out loud notes:
+        # b + (a - b)*scale, where b is the unconditioned output and a the conditioned one:
+        # starting at the output that emerges applying the negative prompt (by default ''),
+        # (-> this is why the unconditioning feels like a hammer)
+        # move toward the positive prompt by an amount controlled by cond_scale.
+        return cfg_apply_threshold(unconditioned_x + (conditioned_x - unconditioned_x) * cond_scale, thresh)
+
+
+class ProgrammableCFGDenoiser(CFGDenoiser):
+    def forward(self, x, sigma, uncond, cond, cond_scale):
+        forward_lambda = lambda x, t, c: self.inner_model(x, t, cond=c)
+        x_new = Sampler.apply_weighted_conditioning_list(x, sigma, forward_lambda, uncond, cond, cond_scale)
+
+        if self.warmup < self.warmup_max:
+            thresh = max(1, 1 + (self.threshold - 1) * (self.warmup / self.warmup_max))
+            self.warmup += 1
+        else:
+            thresh = self.threshold
+        if thresh > self.threshold:
+            thresh = self.threshold
+        return cfg_apply_threshold(x_new, threshold=thresh)
 
 
 class KSampler(Sampler):
@@ -181,7 +203,6 @@ class KSampler(Sampler):
         )
 
         # sigmas are set up in make_schedule - we take the last steps items
-        total_steps = len(self.sigmas)
         sigmas = self.sigmas[-S-1:]
 
         # x_T is variation noise. When an init image is provided (in x0) we need to add
@@ -194,7 +215,7 @@ class KSampler(Sampler):
         else:
             x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0]
 
-        model_wrap_cfg = CFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
+        model_wrap_cfg = ProgrammableCFGDenoiser(self.model, threshold=threshold, warmup=max(0.8*S,S-10))
         extra_args = {
             'cond': conditioning,
             'uncond': unconditional_conditioning,
diff --git a/ldm/models/diffusion/sampler.py b/ldm/models/diffusion/sampler.py
index ff705513f8..42704f1175 100644
--- a/ldm/models/diffusion/sampler.py
+++ b/ldm/models/diffusion/sampler.py
@@ -4,6 +4,8 @@ ldm.models.diffusion.sampler
 Base class for ldm.models.diffusion.ddim, ldm.models.diffusion.ksampler, etc
 
 '''
+from math import ceil
+
 import torch
 import numpy as np
 from tqdm import tqdm
@@ -411,3 +413,54 @@ class Sampler(object):
             return self.model.inner_model.q_sample(x0,ts)
         '''
         return self.model.q_sample(x0,ts)
+
+
+    @classmethod
+    def apply_weighted_conditioning_list(cls, x, t, forward_func, uc, c_or_weighted_c_list, global_guidance_scale):
+        x_in = torch.cat([x] * 2)
+        t_in = torch.cat([t] * 2) # aka sigmas
+
+        uncond_latents = None
+        weighted_cond_list = c_or_weighted_c_list if type(c_or_weighted_c_list) is list else [(c_or_weighted_c_list, 1)]
+
+        # below is fugly omg
+        conditionings = [uc] + [c for c,weight in weighted_cond_list]
+        weights = [1] + [weight for c,weight in weighted_cond_list]
+        chunk_count = ceil(len(conditionings)/2)
+        assert len(conditionings) >= 2, "need at least one uncond and one cond"
+        deltas = None
+        for chunk_index in range(chunk_count):
+            offset = chunk_index*2
+            chunk_size = min(2, len(conditionings)-offset)
+
+            if chunk_size == 1:
+                c_in = conditionings[offset]
+                latents_a = forward_func(x_in[:-1], t_in[:-1], c_in)
+                latents_b = None
+            else:
+                c_in = torch.cat(conditionings[offset:offset+2])
+                latents_a, latents_b = forward_func(x_in, t_in, c_in).chunk(2)
+
+            # first chunk is guaranteed to be 2 entries: uncond_latents + first conditioning
+            if chunk_index == 0:
+                uncond_latents = latents_a
+                deltas = latents_b - uncond_latents
+            else:
+                deltas = torch.cat((deltas, latents_a - uncond_latents))
+                if latents_b is not None:
+                    deltas = torch.cat((deltas, latents_b - uncond_latents))
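+        # e.g. for uc plus two conditionings with weights [1.0, 0.5], deltas now holds
+        # (cond1 - uncond) and (cond2 - uncond); below they are combined as
+        # 1.0*(cond1 - uncond) + 0.5*(cond2 - uncond) and applied at the guidance scale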
+        # merge the weighted deltas together into a single merged delta
+        per_delta_weights = torch.tensor(weights[1:], dtype=deltas.dtype, device=deltas.device)
+        normalize = False
+        if normalize:
+            per_delta_weights /= torch.sum(per_delta_weights)
+        reshaped_weights = per_delta_weights.reshape(per_delta_weights.shape + (1, 1, 1))
+        deltas_merged = torch.sum(deltas * reshaped_weights, dim=0, keepdim=True)
+
+        # old_return_value = super().forward(x, sigma, uncond, cond, cond_scale)
+        # assert(0 == len(torch.nonzero(old_return_value - (uncond_latents + deltas_merged * cond_scale))))
+
+        return uncond_latents + deltas_merged * global_guidance_scale
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index 426fccced3..857a8a8e3e 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -1,3 +1,5 @@
+import math
+
 import torch
 import torch.nn as nn
 from functools import partial
@@ -454,6 +456,207 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     def encode(self, text, **kwargs):
         return self(text, **kwargs)
 
+class WeightedFrozenCLIPEmbedder(FrozenCLIPEmbedder):
+
+    attention_weights_key = "attention_weights"
+
+    def build_token_list_fragment(self, fragment: str, weight: float) -> (torch.Tensor, torch.Tensor):
+        batch_encoding = self.tokenizer(
+            fragment,
+            truncation=True,
+            max_length=self.max_length,
+            return_length=True,
+            return_overflowing_tokens=False,
+            padding='do_not_pad',
+            return_tensors='pt',
+        )
+        return batch_encoding['input_ids'], torch.ones_like(batch_encoding['input_ids']) * weight
+
+
+    def forward(self, text: list, **kwargs):
+        '''
+        :param text: A batch of prompt strings, or, a batch of lists of fragments of prompt strings to which different
+        weights shall be applied.
+        :param kwargs: If the keyword arg "attention_weights" is passed, it shall contain a batch of lists of weights
+        for the prompt fragments. In this case text must contain batches of lists of prompt fragments.
+        :return: A tensor of shape (B, 77, 768) containing weighted embeddings
+        '''
+        if self.attention_weights_key not in kwargs:
+            # fallback to base class implementation
+            return super().forward(text, **kwargs)
+
+        attention_weights = kwargs[self.attention_weights_key]
+        # self.transformer doesn't like receiving "attention_weights" as an argument
+        kwargs.pop(self.attention_weights_key)
+
+        batch_z = None
+        for fragments, weights in zip(text, attention_weights):
+
+            # First, weight tokens in individual fragments by scaling the feature vectors as requested (effectively
+            # applying a multiplier to the CFG scale on a per-token basis).
+            # For tokens weighted <1, intuitively we want SD to become not merely *less* interested in the concept
+            # captured by the fragment but actually *dis*interested in it (a 0.01 interest in "red" is still an active
+            # interest, however small, in redness; what the user probably intends when they attach the number 0.01 to
+            # "red" is to tell SD that it should almost completely *ignore* redness).
+            # To do this, the embedding is lerped away from base_embedding in the direction of an embedding for a prompt
+            # string from which the low-weighted fragment has been simply removed. The closer the weight is to zero, the
+            # closer the resulting embedding is to an embedding for a prompt that simply lacks this fragment.
+
+            # handle weights >=1
+            tokens, per_token_weights = self.get_tokens_and_weights(fragments, weights)
+            base_embedding = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
+
+            # this is our starting point
+            embeddings = base_embedding.unsqueeze(0)
+            per_embedding_weights = [1.0]
+
+            # now handle weights <1
+            # Do this by building extra embeddings tensors that lack the words being <1 weighted. These will be lerped
+            # with the embeddings tensors that have the words, such that if the weight of a word is 0.5, the resulting
+            # embedding will be exactly half-way between the unweighted prompt and the prompt with the <1 weighted words
+            # removed.
+            # eg for "mountain:1 man:0.5", intuitively the "man" should be "half-gone". therefore, append an embedding
+            # for "mountain" (i.e. without "man") to the already-produced embedding for "mountain man", and weight it
+            # such that the resulting lerped embedding is exactly half-way between "mountain man" and "mountain".
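+            # e.g. for weights [1.0, 0.5], the embedding without the 0.5-weighted fragment
+            # gets lerp weight tan((1 - 0.5) * pi/2) == 1.0; after normalization the two
+            # embeddings are mixed 50/50, i.e. exactly half-way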
+            for index, fragment_weight in enumerate(weights):
+                if fragment_weight < 1:
+                    fragments_without_this = fragments[:index] + fragments[index+1:]
+                    weights_without_this = weights[:index] + weights[index+1:]
+                    tokens, per_token_weights = self.get_tokens_and_weights(fragments_without_this, weights_without_this)
+                    embedding_without_this = self.build_weighted_embedding_tensor(tokens, per_token_weights, **kwargs)
+
+                    embeddings = torch.cat((embeddings, embedding_without_this.unsqueeze(0)), dim=1)
+                    # weight of the embedding *without* this fragment gets *stronger* as its weight approaches 0
+                    # if fragment_weight = 0, basically we want embedding_without_this to completely overwhelm base_embedding
+                    # therefore:
+                    # fragment_weight = 1: we are at base_z => lerp weight 0
+                    # fragment_weight = 0.5: we are halfway between base_z and here => lerp weight 1
+                    # fragment_weight = 0: we're now entirely overriding base_z ==> lerp weight inf
+                    # so let's use tan(), because:
+                    # tan is 0.0 at 0,
+                    #        1.0 at PI/4, and
+                    #        inf at PI/2
+                    # -> tan((1-weight)*PI/2) should give us ideal lerp weights
+                    epsilon = 1e-9
+                    fragment_weight = max(epsilon, fragment_weight) # inf is bad
+                    embedding_lerp_weight = math.tan((1.0 - fragment_weight) * math.pi / 2)
+                    # todo handle negative weight?
+
+                    per_embedding_weights.append(embedding_lerp_weight)
+
+            lerped_embeddings = self.apply_embedding_weights(embeddings, per_embedding_weights, normalize=True).squeeze(0)
+
+            print(f"assembled tokens for '{fragments}' into tensor of shape {lerped_embeddings.shape}")
+
+            # append to batch, concatenating along the batch axis
+            batch_z = lerped_embeddings.unsqueeze(0) if batch_z is None else torch.cat((batch_z, lerped_embeddings.unsqueeze(0)), dim=0)
+
+        # should have shape (B, 77, 768)
+        print(f"assembled all tokens into tensor of shape {batch_z.shape}")
+
+        return batch_z
+
+    @classmethod
+    def apply_embedding_weights(cls, embeddings: torch.Tensor, per_embedding_weights: list[float], normalize:bool) -> torch.Tensor:
+        per_embedding_weights = torch.tensor(per_embedding_weights, dtype=embeddings.dtype, device=embeddings.device)
+        if normalize:
+            per_embedding_weights = per_embedding_weights / torch.sum(per_embedding_weights)
+        reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1, 1,))
+        #reshaped_weights = per_embedding_weights.reshape(per_embedding_weights.shape + (1,1,)).expand(embeddings.shape)
+        # weighted sum collapses the embeddings axis: result has shape (batch, 77, 768)
+        return torch.sum(embeddings * reshaped_weights, dim=1)
+
+
+    def get_tokens_and_weights(self, fragments: list[str], weights: list[float]) -> (torch.Tensor, torch.Tensor):
+        '''
+
+        :param fragments:
+        :param weights: Per-fragment weights (CFG scaling). No need for these to be normalized. They will not be normalized here and that's fine.
+        :return:
+        '''
+        # empty is meaningful
+        if len(fragments) == 0 and len(weights) == 0:
+            fragments = ['']
+            weights = [1]
+        item_encodings = self.tokenizer(
+            fragments,
+            truncation=True,
+            max_length=self.max_length,
+            return_overflowing_tokens=False,
+            padding='do_not_pad',
+            return_tensors=None, # just give me a list of ints
+        )['input_ids']
+        all_tokens = []
+        per_token_weights = []
+        print("all fragments:", fragments, weights)
+        for index, fragment in enumerate(item_encodings):
+            weight = weights[index]
+            print("processing fragment", fragment, weight)
+            fragment_tokens = item_encodings[index]
+            print("fragment", fragment, "processed to", fragment_tokens)
+            # trim bos and eos markers before appending
+            all_tokens.extend(fragment_tokens[1:-1])
+            per_token_weights.extend([weight] * (len(fragment_tokens) - 2))
+
+        if len(all_tokens) > self.max_length - 2:
+            print("prompt is too long and has been truncated")
+            all_tokens = all_tokens[:self.max_length - 2]
+
+        # pad out to a 77-entry array: [bos_token, <prompt tokens>, eos_token, ..., eos_token]
+        # (77 = self.max_length)
+        pad_length = self.max_length - 1 - len(all_tokens)
+        all_tokens.insert(0, self.tokenizer.bos_token_id)
+        all_tokens.extend([self.tokenizer.eos_token_id] * pad_length)
+        per_token_weights.insert(0, 1)
+        per_token_weights.extend([1] * pad_length)
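+        # e.g. fragments ['mountain', 'man'] with weights [1.0, 0.5] produce (token counts
+        # illustrative) tokens  [bos, <mountain tokens>, <man tokens>, eos, eos, ...] and
+        # weights               [1.0, 1.0, ...,          0.5, ...,     1.0, 1.0, ...]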
+        all_tokens_tensor = torch.tensor(all_tokens, dtype=torch.long).to(self.device)
+        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32).to(self.device)
+        print(f"assembled all_tokens_tensor with shape {all_tokens_tensor.shape}")
+        return all_tokens_tensor, per_token_weights_tensor
+
+    def build_weighted_embedding_tensor(self, tokens: torch.Tensor, per_token_weights: torch.Tensor, weight_delta_from_empty=True, **kwargs) -> torch.Tensor:
+        '''
+        Build a tensor representing the passed-in tokens, each of which has a weight.
+        :param tokens: A tensor of shape (77) containing token ids (integers)
+        :param per_token_weights: A tensor of shape (77) containing weights (floats)
+        :param weight_delta_from_empty: Whether to weight each token's whole feature vector or just its distance from an "empty" feature vector
+        :param kwargs: passed on to self.transformer()
+        :return: A tensor of shape (1, 77, 768) representing the requested weighted embeddings.
+        '''
+        #print(f"building weighted embedding tensor for {tokens} with weights {per_token_weights}")
+        z = self.transformer(input_ids=tokens.unsqueeze(0), **kwargs)
+        batch_weights_expanded = per_token_weights.reshape(per_token_weights.shape + (1,)).expand(z.shape)
+
+        if weight_delta_from_empty:
+            empty_tokens = self.tokenizer([''] * z.shape[0],
+                                          truncation=True,
+                                          max_length=self.max_length,
+                                          padding='max_length',
+                                          return_tensors='pt'
+                                          )['input_ids'].to(self.device)
+            empty_z = self.transformer(input_ids=empty_tokens, **kwargs)
+            z_delta_from_empty = z - empty_z
+            weighted_z = empty_z + (z_delta_from_empty * batch_weights_expanded)
+
+            weighted_z_delta_from_empty = (weighted_z-empty_z)
+            print("weighted z has delta from empty with sum", weighted_z_delta_from_empty.sum().item(), "mean", weighted_z_delta_from_empty.mean().item() )
+
+            #print("using empty-delta method, first 5 rows:")
+            #print(weighted_z[:5])
+
+            return weighted_z
+
+        else:
+            original_mean = z.mean()
+            z *= batch_weights_expanded
+            after_weighting_mean = z.mean()
+            # correct the mean. not sure if this is right but it's what the automatic1111 fork of SD does
+            mean_correction_factor = original_mean/after_weighting_mean
+            z *= mean_correction_factor
+            return z
+
 
 class FrozenCLIPTextEmbedder(nn.Module):
     """
diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py
new file mode 100644
index 0000000000..207475d02e
--- /dev/null
+++ b/tests/test_prompt_parser.py
@@ -0,0 +1,136 @@
+import unittest
+
+from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, FlattenedPrompt
+
+def parse_prompt(prompt_string):
+    pp = PromptParser()
+    #print(f"parsing '{prompt_string}'")
+    parse_result = pp.parse(prompt_string)
+    #print(f"-> parsed '{prompt_string}' to {parse_result}")
+    return parse_result
+
+class PromptParserTestCase(unittest.TestCase):
+
+    def test_empty(self):
+        self.assertEqual(Conjunction([FlattenedPrompt([('', 1)])]), parse_prompt(''))
+
+    def test_basic(self):
+        self.assertEqual(Conjunction([FlattenedPrompt([('fire (flames)', 1)])]), parse_prompt("fire (flames)"))
+        self.assertEqual(Conjunction([FlattenedPrompt([("fire flames", 1)])]), parse_prompt("fire flames"))
+        self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames", 1)])]), parse_prompt("fire, flames"))
+        self.assertEqual(Conjunction([FlattenedPrompt([("fire, flames , fire", 1)])]), parse_prompt("fire, flames , fire"))
+
+    def test_attention(self):
+        self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.5)])]), parse_prompt("0.5(flames)"))
+        self.assertEqual(Conjunction([FlattenedPrompt([('fire flames', 0.5)])]), parse_prompt("0.5(fire flames)"))
+        self.assertEqual(Conjunction([FlattenedPrompt([('flames', 1.1)])]), parse_prompt("+(flames)"))
+        self.assertEqual(Conjunction([FlattenedPrompt([('flames', 0.9)])]), parse_prompt("-(flames)"))
+        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1), ('flames', 0.5)])]), parse_prompt("fire 0.5(flames)"))
+        self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(1.1, 2))])]), parse_prompt("++(flames)"))
+        self.assertEqual(Conjunction([FlattenedPrompt([('flames', pow(0.9, 2))])]), parse_prompt("--(flames)"))
+        self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames', pow(1.1, 3))])]), parse_prompt("---(flowers) +++flames"))
+        self.assertEqual(Conjunction([FlattenedPrompt([('flowers', pow(0.9, 3)), ('flames+', pow(1.1, 3))])]),
+                         parse_prompt("---(flowers) +++flames+"))
+        self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1)])]),
+                         parse_prompt("+(pretty flowers)"))
+        self.assertEqual(Conjunction([FlattenedPrompt([('pretty flowers', 1.1), (', the flames are too hot', 1)])]),
+                         parse_prompt("+(pretty flowers), the flames are too hot"))
+
+    def test_no_parens_attention_runon(self):
+        self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("++fire flames"))
+        self.assertEqual(Conjunction([FlattenedPrompt([('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("--fire flames"))
+        self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(1.1, 2)), ('flames', 1.0)])]), parse_prompt("flowers ++fire flames"))
+        self.assertEqual(Conjunction([FlattenedPrompt([('flowers', 1.0), ('fire', pow(0.9, 2)), ('flames', 1.0)])]), parse_prompt("flowers --fire flames"))
+
+
+    def test_explicit_conjunction(self):
+        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and(1,1)'))
+        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])]), parse_prompt('("fire", "flames").and()'))
+        self.assertEqual(
+            Conjunction([FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire flames", "mountain man").and()'))
+        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 2.0)]), FlattenedPrompt([('flames', 0.9)])]), parse_prompt('("2.0(fire)", "-flames").and()'))
+        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)]),
+                                      FlattenedPrompt([('mountain man', 1.0)])]), parse_prompt('("fire", "flames", "mountain man").and()'))
+
+    def test_conjunction_weights(self):
+        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[2.0,1.0]), parse_prompt('("fire", "flames").and(2,1)'))
+        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('flames', 1.0)])], weights=[1.0,2.0]), parse_prompt('("fire", "flames").and(1,2)'))
+
+        # one assertRaises block per bad prompt - if both calls share a block, only the first is actually checked
+        with self.assertRaises(PromptParser.ParsingException):
+            parse_prompt('("fire", "flames").and(2)')
+        with self.assertRaises(PromptParser.ParsingException):
+            parse_prompt('("fire", "flames").and(2,1,2)')
+
+    def test_complex_conjunction(self):
+        self.assertEqual(Conjunction([FlattenedPrompt([("mountain man", 1.0)]), FlattenedPrompt([("a person with a hat", 1.0), ("riding a bicycle", pow(1.1,2))])], weights=[0.5, 0.5]),
+                         parse_prompt("(\"mountain man\", \"a person with a hat ++(riding a bicycle)\").and(0.5, 0.5)"))
+
+    def test_badly_formed(self):
+        def make_untouched_prompt(prompt):
+            return Conjunction([FlattenedPrompt([(prompt, 1.0)])])
+
+        def assert_if_prompt_string_not_untouched(prompt):
+            self.assertEqual(make_untouched_prompt(prompt), parse_prompt(prompt))
+
+        assert_if_prompt_string_not_untouched('a test prompt')
+        assert_if_prompt_string_not_untouched('a badly (formed test prompt')
+        assert_if_prompt_string_not_untouched('a badly formed test+ prompt')
+        assert_if_prompt_string_not_untouched('a badly (formed test+ prompt')
+        assert_if_prompt_string_not_untouched('a badly (formed test+ )prompt')
+        assert_if_prompt_string_not_untouched('(((a badly (formed test+ )prompt')
+        assert_if_prompt_string_not_untouched('(a (ba)dly (f)ormed test+ prompt')
+        self.assertEqual(Conjunction([FlattenedPrompt([('(a (ba)dly (f)ormed test+', 1.0), ('prompt', 1.1)])]),
+                         parse_prompt('(a (ba)dly (f)ormed test+ +prompt'))
+        self.assertEqual(Conjunction([Blend([FlattenedPrompt([('((a badly (formed test+', 1.0)])], weights=[1.0])]),
+                         parse_prompt('("((a badly (formed test+ ").blend(1.0)'))
+
+    def test_blend(self):
+        self.assertEqual(Conjunction(
+            [Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)])], [0.7, 0.3])]),
+            parse_prompt("(\"fire\", \"fire flames\").blend(0.7, 0.3)")
+        )
+        self.assertEqual(Conjunction([Blend(
+            [FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('fire flames', 1.0)]), FlattenedPrompt([('hi', 1.0)])],
+            [0.7, 0.3, 1.0])]),
+            parse_prompt("(\"fire\", \"fire flames\", \"hi\").blend(0.7, 0.3, 1.0)")
+        )
+        self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]),
+                                             FlattenedPrompt([('fire flames', 1.0), ('hot', pow(1.1, 2))]),
+                                             FlattenedPrompt([('hi', 1.0)])],
+                                            weights=[0.7, 0.3, 1.0])]),
+                         parse_prompt("(\"fire\", \"fire flames ++(hot)\", \"hi\").blend(0.7, 0.3, 1.0)")
+                         )
+        # blend a single entry is not a failure
+        self.assertEqual(Conjunction([Blend([FlattenedPrompt([('fire', 1.0)])], [0.7])]),
+                         parse_prompt("(\"fire\").blend(0.7)")
+                         )
+        # blend with empty
+        self.assertEqual(
+            Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
+            parse_prompt("(\"fire\", \"\").blend(0.7, 1)")
+        )
+        self.assertEqual(
+            Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([('', 1.0)])], [0.7, 1.0])]),
+            parse_prompt("(\"fire\", \" \").blend(0.7, 1)")
+        )
+        self.assertEqual(
+            Conjunction([Blend([FlattenedPrompt([('fire', 1.0)]), FlattenedPrompt([(',', 1.0)])], [0.7, 1.0])]),
+            parse_prompt("(\"fire\", \" , \").blend(0.7, 1)")
+        )
+
+
+    def test_nested(self):
+        self.assertEqual(Conjunction([FlattenedPrompt([('fire', 1.0), ('flames', 2.0), ('trees', 3.0)])]),
+                         parse_prompt('fire 2.0(flames 1.5(trees))'))
+        self.assertEqual(Conjunction([Blend(prompts=[FlattenedPrompt([('fire', 1.0), ('flames', 1.2100000000000002)]),
+                                                     FlattenedPrompt([('mountain', 1.0), ('man', 2.0)])],
+                                            weights=[1.0, 1.0])]),
+                         parse_prompt('("fire ++(flames)", "mountain 2(man)").blend(1,1)'))
+
+if __name__ == '__main__':
+    unittest.main()
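For reference, a minimal sketch of how the parser added above is driven (the printed values are the Conjunction/Blend/FlattenedPrompt reprs defined in ldm/invoke/prompt_parser.py; get_uc_and_c() then turns these structures into weighted conditioning tensors):

    from ldm.invoke.prompt_parser import PromptParser

    pp = PromptParser()
    # a blend: two quoted prompts mixed 70/30 at the embedding level
    print(pp.parse('("fire", "ice").blend(0.7, 0.3)'))
    # a conjunction: both prompts guide every denoising step, with per-prompt weights
    print(pp.parse('("mountain man", "a person with a hat").and(1, 0.5)'))
    # attention weighting inside a single prompt
    print(pp.parse('a mountain 0.5(man)'))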