From c52ba1b022ef83ddbae5cde228e08bccd1087c03 Mon Sep 17 00:00:00 2001
From: _nderscore <_@nderscore.com>
Date: Wed, 31 Aug 2022 18:00:10 +0000
Subject: [PATCH] feat: simplify and enhance prompt weight splitting (#258)

* feat: simplify and enhance prompt weight splitting

* fix: don't shadow the prompt variable

* feat: enable backslash-escaped colons in prompts
---
 ldm/simplet2i.py | 83 +++++++++++++++++++-------------------------------------
 1 file changed, 32 insertions(+), 51 deletions(-)

diff --git a/ldm/simplet2i.py b/ldm/simplet2i.py
index d969ac5e23..82839db875 100644
--- a/ldm/simplet2i.py
+++ b/ldm/simplet2i.py
@@ -487,22 +487,19 @@ class T2I:
 
         uc = self.model.get_learned_conditioning([''])
 
-        # weighted sub-prompts
-        subprompts, weights = T2I._split_weighted_subprompts(prompt)
-        if len(subprompts) > 1:
+        # get weighted sub-prompts
+        weighted_subprompts = T2I._split_weighted_subprompts(prompt, skip_normalize)
+
+        if len(weighted_subprompts) > 1:
             # i dont know if this is correct.. but it works
             c = torch.zeros_like(uc)
-            # get total weight for normalizing
-            totalWeight = sum(weights)
             # normalize each "sub prompt" and add it
-            for i in range(0, len(subprompts)):
-                weight = weights[i]
-                if not skip_normalize:
-                    weight = weight / totalWeight
-                self._log_tokenization(subprompts[i])
+            for i in range(0, len(weighted_subprompts)):
+                subprompt, weight = weighted_subprompts[i]
+                self._log_tokenization(subprompt)
                 c = torch.add(
                     c,
-                    self.model.get_learned_conditioning([subprompts[i]]),
+                    self.model.get_learned_conditioning([subprompt]),
                     alpha=weight,
                 )
         else:   # just standard 1 prompt
@@ -616,52 +613,36 @@ class T2I:
         image = torch.from_numpy(image)
         return 2.0 * image - 1.0
 
-    def _split_weighted_subprompts(text):
+    def _split_weighted_subprompts(text, skip_normalize=False):
         """
         grabs all text up to the first occurrence of ':'
         uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
         if ':' has no value defined, defaults to 1.0
         repeats until no text remaining
         """
-        remaining = len(text)
-        prompts = []
-        weights = []
-        while remaining > 0:
-            if ':' in text:
-                idx = text.index(':')    # first occurrence from start
-                # grab up to index as sub-prompt
-                prompt = text[:idx]
-                remaining -= idx
-                # remove from main text
-                text = text[idx + 1 :]
-                # find value for weight
-                if ' ' in text:
-                    idx = text.index(' ')    # first occurence
-                else:   # no space, read to end
-                    idx = len(text)
-                if idx != 0:
-                    try:
-                        weight = float(text[:idx])
-                    except:  # couldn't treat as float
-                        print(
-                            f"Warning: '{text[:idx]}' is not a value, are you missing a space?"
-                        )
-                        weight = 1.0
-                else:   # no value found
-                    weight = 1.0
-                # remove from main text
-                remaining -= idx
-                text = text[idx + 1 :]
-                # append the sub-prompt and its weight
-                prompts.append(prompt)
-                weights.append(weight)
-            else:   # no : found
-                if len(text) > 0:   # there is still text though
-                    # take remainder as weight 1
-                    prompts.append(text)
-                    weights.append(1.0)
-                remaining = 0
-        return prompts, weights
+        prompt_parser = re.compile("""
+            (?P<prompt>         # capture group for 'prompt'
+            (?:\\\:|[^:])+      # match one or more non ':' characters or escaped colons '\:'
+            )                   # end 'prompt'
+            (?:                 # non-capture group
+            :+                  # match one or more ':' characters
+            (?P<weight>         # capture group for 'weight'
+            -?\d+(?:\.\d+)?     # match positive or negative integer or decimal number
+            )?                  # end weight capture group, make optional
+            \s*                 # strip spaces after weight
+            |                   # OR
+            $                   # else, if no ':' then match end of line
+            )                   # end non-capture group
+            """, re.VERBOSE)
+        parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
+        if skip_normalize:
+            return parsed_prompts
+        weight_sum = sum(map(lambda x: x[1], parsed_prompts))
+        if weight_sum == 0:
+            print("Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
+            equal_weight = 1 / len(parsed_prompts)
+            return [(x[0], equal_weight) for x in parsed_prompts]
+        return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
 
     # shows how the prompt is tokenized
     # usually tokens have '</w>' to indicate end-of-word,
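
A quick way to sanity-check the new splitting behavior is to call the parser directly once the patch is applied. A minimal sketch follows; the example prompt and the expected output are illustrative assumptions for review, not part of the patch:

    # Illustrative only: assumes this patch is applied to ldm/simplet2i.py.
    from ldm.simplet2i import T2I

    # "\\:" in the Python literal is the escaped colon "\:" a user would type.
    parsed = T2I._split_weighted_subprompts('a castle:2 a forest:1 escaped\\: colon')
    print(parsed)
    # Expected (weights normalized to sum to 1.0):
    # [('a castle', 0.5), ('a forest', 0.25), ('escaped: colon', 0.25)]

    # With skip_normalize=True, the raw weights are returned instead:
    # [('a castle', 2.0), ('a forest', 1.0), ('escaped: colon', 1.0)]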