From 44e4090909c0375f9344f647e85a2107f6b1ded0 Mon Sep 17 00:00:00 2001
From: Damian at mba
Date: Mon, 24 Oct 2022 11:16:52 +0200
Subject: [PATCH] re-enable legacy blend syntax

---
 backend/invoke_ai_web_server.py |  2 +-
 backend/server.py               |  2 +-
 ldm/invoke/args.py              |  2 +-
 ldm/invoke/conditioning.py      | 70 ++++----------------------
 ldm/invoke/prompt_parser.py     | 88 ++++++++++++++++++++++++++++++---
 tests/test_prompt_parser.py     | 41 ++++++++++++++-
 6 files changed, 134 insertions(+), 71 deletions(-)

diff --git a/backend/invoke_ai_web_server.py b/backend/invoke_ai_web_server.py
index 96ecda1af1..dabe072f80 100644
--- a/backend/invoke_ai_web_server.py
+++ b/backend/invoke_ai_web_server.py
@@ -14,7 +14,7 @@ from threading import Event
 
 from ldm.invoke.args import Args, APP_ID, APP_VERSION, calculate_init_img_hash
 from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
-from ldm.invoke.conditioning import split_weighted_subprompts
+from ldm.invoke.prompt_parser import split_weighted_subprompts
 
 from backend.modules.parameters import parameters_to_command
 
diff --git a/backend/server.py b/backend/server.py
index 7b8a8a5a69..8ad861356c 100644
--- a/backend/server.py
+++ b/backend/server.py
@@ -33,7 +33,7 @@ from ldm.generate import Generate
 from ldm.invoke.restoration import Restoration
 from ldm.invoke.pngwriter import PngWriter, retrieve_metadata
 from ldm.invoke.args import APP_ID, APP_VERSION, calculate_init_img_hash
-from ldm.invoke.conditioning import split_weighted_subprompts
+from ldm.invoke.prompt_parser import split_weighted_subprompts
 
 from modules.parameters import parameters_to_command
 
diff --git a/ldm/invoke/args.py b/ldm/invoke/args.py
index 26920f28ea..12e9f96f6b 100644
--- a/ldm/invoke/args.py
+++ b/ldm/invoke/args.py
@@ -92,7 +92,7 @@ import copy
 import base64
 import functools
 import ldm.invoke.pngwriter
-from ldm.invoke.conditioning import split_weighted_subprompts
+from ldm.invoke.prompt_parser import split_weighted_subprompts
 
 SAMPLER_CHOICES = [
     'ddim',
diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py
index 52d40312ac..65459b5c5f 100644
--- a/ldm/invoke/conditioning.py
+++ b/ldm/invoke/conditioning.py
@@ -41,9 +41,15 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
     pp = PromptParser()
 
-    # we don't support conjunctions for now
-    parsed_prompt: Union[FlattenedPrompt, Blend] = pp.parse(prompt_string_cleaned).prompts[0]
-    parsed_negative_prompt: FlattenedPrompt = pp.parse(unconditioned_words).prompts[0]
+    parsed_prompt: Union[FlattenedPrompt, Blend] = None
+    legacy_blend: Blend = pp.parse_legacy_blend(prompt_string_cleaned)
+    if legacy_blend is not None:
+        parsed_prompt = legacy_blend
+    else:
+        # we don't support conjunctions for now
+        parsed_prompt = pp.parse_conjunction(prompt_string_cleaned).prompts[0]
+
+    parsed_negative_prompt: FlattenedPrompt = pp.parse_conjunction(unconditioned_words).prompts[0]
 
     print("parsed prompt to", parsed_prompt)
 
     conditioning = None
@@ -146,61 +152,3 @@ def get_tokens_length(model, fragments: list[Fragment]):
     return sum([len(x) for x in tokens])
 
 
-def split_weighted_subprompts(text, skip_normalize=False)->list:
-    """
-    grabs all text up to the first occurrence of ':'
-    uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
-    if ':' has no value defined, defaults to 1.0
-    repeats until no text remaining
-    """
-    prompt_parser = re.compile("""
-            (?P<prompt>         # capture group for 'prompt'
-            (?:\\\:|[^:])+      # match one or more non ':' characters or escaped colons '\:'
-            )                   # end 'prompt'
-            (?:                 # non-capture group
-            :+                  # match one or more ':' characters
-            (?P<weight>         # capture group for 'weight'
-            -?\d+(?:\.\d+)?     # match positive or negative integer or decimal number
-            )?                  # end weight capture group, make optional
-            \s*                 # strip spaces after weight
-            |                   # OR
-            $                   # else, if no ':' then match end of line
-            )                   # end non-capture group
-            """, re.VERBOSE)
-    parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
-        match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
-    if skip_normalize:
-        return parsed_prompts
-    weight_sum = sum(map(lambda x: x[1], parsed_prompts))
-    if weight_sum == 0:
-        print(
-            "Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
-        equal_weight = 1 / max(len(parsed_prompts), 1)
-        return [(x[0], equal_weight) for x in parsed_prompts]
-    return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
-
-# shows how the prompt is tokenized
-# usually tokens have '</w>' to indicate end-of-word,
-# but for readability it has been replaced with ' '
-def log_tokenization(text, model, log=False, weight=1):
-    if not log:
-        return
-    tokens = model.cond_stage_model.tokenizer._tokenize(text)
-    tokenized = ""
-    discarded = ""
-    usedTokens = 0
-    totalTokens = len(tokens)
-    for i in range(0, totalTokens):
-        token = tokens[i].replace('</w>', ' ')
-        # alternate color
-        s = (usedTokens % 6) + 1
-        if i < model.cond_stage_model.max_length:
-            tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
-            usedTokens += 1
-        else: # over max token length
-            discarded = discarded + f"\x1b[0;3{s};40m{token}"
-    print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
-    if discarded != "":
-        print(
-            f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
-        )
diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py
index 830c5313e3..3a96d664f0 100644
--- a/ldm/invoke/prompt_parser.py
+++ b/ldm/invoke/prompt_parser.py
@@ -1,6 +1,6 @@
 import string
-from typing import Union
-
+from typing import Union, Optional
+import re
 import pyparsing as pp
 
 class Prompt():
@@ -223,10 +223,10 @@ class PromptParser():
 
     def __init__(self, attention_plus_base=1.1, attention_minus_base=0.9):
 
-        self.root = build_parser_syntax(attention_plus_base, attention_minus_base)
+        self.conjunction, self.prompt = build_parser_syntax(attention_plus_base, attention_minus_base)
 
-    def parse(self, prompt: str) -> Conjunction:
+    def parse_conjunction(self, prompt: str) -> Conjunction:
         '''
         :param prompt: The prompt string to parse
         :return: a Conjunction representing the parsed results.
         '''
@@ -236,13 +236,25 @@ class PromptParser():
         if len(prompt.strip()) == 0:
             return Conjunction(prompts=[FlattenedPrompt([('', 1.0)])], weights=[1.0])
 
-        root = self.root.parse_string(prompt)
+        root = self.conjunction.parse_string(prompt)
         #print(f"'{prompt}' parsed to root", root)
         #fused = fuse_fragments(parts)
         #print("fused to", fused)
 
         return self.flatten(root[0])
 
+    def parse_legacy_blend(self, text: str) -> Optional[Blend]:
+        weighted_subprompts = split_weighted_subprompts(text, skip_normalize=False)
+        if len(weighted_subprompts) == 1:
+            return None
+        strings = [x[0] for x in weighted_subprompts]
+        weights = [x[1] for x in weighted_subprompts]
+
+        parsed_conjunctions = [self.parse_conjunction(x) for x in strings]
+        flattened_prompts = [x.prompts[0] for x in parsed_conjunctions]
+
+        return Blend(prompts=flattened_prompts, weights=weights, normalize_weights=True)
+
 
     def flatten(self, root: Conjunction) -> Conjunction:
         """
@@ -596,4 +608,68 @@ def build_parser_syntax(attention_plus_base: float, attention_minus_base: float)
     conjunction.set_debug(False)
 
     # top-level is a conjunction of one or more blends or prompts
-    return conjunction
+    return conjunction, prompt
+
+
+
+def split_weighted_subprompts(text, skip_normalize=False)->list:
+    """
+    Legacy blend parsing.
+
+    grabs all text up to the first occurrence of ':'
+    uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
+    if ':' has no value defined, defaults to 1.0
+    repeats until no text remaining
+    """
+    prompt_parser = re.compile("""
+            (?P<prompt>         # capture group for 'prompt'
+            (?:\\\:|[^:])+      # match one or more non ':' characters or escaped colons '\:'
+            )                   # end 'prompt'
+            (?:                 # non-capture group
+            :+                  # match one or more ':' characters
+            (?P<weight>         # capture group for 'weight'
+            -?\d+(?:\.\d+)?     # match positive or negative integer or decimal number
+            )?                  # end weight capture group, make optional
+            \s*                 # strip spaces after weight
+            |                   # OR
+            $                   # else, if no ':' then match end of line
+            )                   # end non-capture group
+            """, re.VERBOSE)
+    parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
+        match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
+    if skip_normalize:
+        return parsed_prompts
+    weight_sum = sum(map(lambda x: x[1], parsed_prompts))
+    if weight_sum == 0:
+        print(
+            "Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
+        equal_weight = 1 / max(len(parsed_prompts), 1)
+        return [(x[0], equal_weight) for x in parsed_prompts]
+    return [(x[0], x[1] / weight_sum) for x in parsed_prompts]
+
+
+# shows how the prompt is tokenized
+# usually tokens have '</w>' to indicate end-of-word,
+# but for readability it has been replaced with ' '
+def log_tokenization(text, model, log=False, weight=1):
+    if not log:
+        return
+    tokens = model.cond_stage_model.tokenizer._tokenize(text)
+    tokenized = ""
+    discarded = ""
+    usedTokens = 0
+    totalTokens = len(tokens)
+    for i in range(0, totalTokens):
+        token = tokens[i].replace('</w>', ' ')
+        # alternate color
+        s = (usedTokens % 6) + 1
+        if i < model.cond_stage_model.max_length:
+            tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
+            usedTokens += 1
+        else: # over max token length
+            discarded = discarded + f"\x1b[0;3{s};40m{token}"
+    print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
+    if discarded != "":
+        print(
+            f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
+        )
diff --git a/tests/test_prompt_parser.py b/tests/test_prompt_parser.py
index 0c4d9106db..486265d2f5 100644
--- a/tests/test_prompt_parser.py
+++ b/tests/test_prompt_parser.py
@@ -9,7 +9,7 @@ from ldm.invoke.prompt_parser import PromptParser, Blend, Conjunction, Flattened
 def parse_prompt(prompt_string):
     pp = PromptParser()
     #print(f"parsing '{prompt_string}'")
-    parse_result = pp.parse(prompt_string)
+    parse_result = pp.parse_conjunction(prompt_string)
     #print(f"-> parsed '{prompt_string}' to {parse_result}")
     return parse_result
 
@@ -351,6 +351,45 @@ class PromptParserTestCase(unittest.TestCase):
         self.assertEqual(Conjunction([FlattenedPrompt([('mountain', 1), CrossAttentionControlSubstitute([Fragment('(((', 1)], [Fragment('m(on))key', 1)])])]),
                          parse_prompt('mountain (\(\(\().swap(m\(on\)\)key)'))
 
+    def test_legacy_blend(self):
+        pp = PromptParser()
+
+        self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
+                                FlattenedPrompt([('man mountain', 1)])],
+                               weights=[0.5,0.5]),
+                         pp.parse_legacy_blend('mountain man:1 man mountain:1'))
+
+        self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
+                                FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
+                               weights=[0.5,0.5]),
+                         pp.parse_legacy_blend('mountain+ man:1 man mountain-:1'))
+
+        self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
+                                FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
+                               weights=[0.5,0.5]),
+                         pp.parse_legacy_blend('mountain+ man:1 man mountain-'))
+
+        self.assertEqual(Blend([FlattenedPrompt([('mountain', 1.1), ('man', 1)]),
+                                FlattenedPrompt([('man', 1), ('mountain', 0.9)])],
+                               weights=[0.5,0.5]),
+                         pp.parse_legacy_blend('mountain+ man: man mountain-:'))
+
+        self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
+                                FlattenedPrompt([('man mountain', 1)])],
+                               weights=[0.75,0.25]),
+                         pp.parse_legacy_blend('mountain man:3 man mountain:1'))
+
+        self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
+                                FlattenedPrompt([('man mountain', 1)])],
+                               weights=[1.0,0.0]),
+                         pp.parse_legacy_blend('mountain man:3 man mountain:0'))
+
+        self.assertEqual(Blend([FlattenedPrompt([('mountain man', 1)]),
+                                FlattenedPrompt([('man mountain', 1)])],
+                               weights=[0.8,0.2]),
+                         pp.parse_legacy_blend('"mountain man":4 man mountain'))
+
+
     def test_single(self):
         # todo handle this
         #self.assertEqual(make_basic_conjunction(['a badly formed +test prompt']),
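
Reviewer note, not part of the patch: the sketch below is a minimal, self-contained illustration of the legacy blend weighting this patch re-enables. It mirrors the split-and-normalize step of split_weighted_subprompts; the names LEGACY_BLEND_RE and split_weighted_subprompts_sketch are invented for this example only. In the patched code path, PromptParser.parse_legacy_blend additionally runs each subprompt through parse_conjunction and wraps the flattened prompts in a Blend, and it returns None when only one subprompt is found so that ordinary prompts fall through to the regular parser.

# Minimal sketch, stdlib only (assumes Python 3.8+). Hypothetical names:
# LEGACY_BLEND_RE and split_weighted_subprompts_sketch stand in for the real
# split_weighted_subprompts in ldm/invoke/prompt_parser.py.
import re

LEGACY_BLEND_RE = re.compile(r"""
    (?P<prompt>                       # text of one subprompt
        (?:\\\:|[^:])+                # one or more non-':' chars, or escaped colons '\:'
    )
    (?:
        :+                            # the ':' separator
        (?P<weight>-?\d+(?:\.\d+)?)?  # optional signed integer or decimal weight
        \s*                           # eat spaces after the weight
        |                             # OR no ':' at all:
        $                             # last subprompt, weight defaults to 1
    )
    """, re.VERBOSE)

def split_weighted_subprompts_sketch(text):
    # 'a:3 b:1' -> [('a', 3.0), ('b', 1.0)]; a missing weight defaults to 1.0
    parsed = [(m.group("prompt").replace("\\:", ":"), float(m.group("weight") or 1))
              for m in LEGACY_BLEND_RE.finditer(text)]
    # normalize so the weights sum to 1, falling back to even weights at zero
    total = sum(weight for _, weight in parsed)
    if total == 0:
        return [(prompt, 1 / max(len(parsed), 1)) for prompt, _ in parsed]
    return [(prompt, weight / total) for prompt, weight in parsed]

if __name__ == "__main__":
    # mirrors the weights=[0.75, 0.25] case in test_legacy_blend above
    print(split_weighted_subprompts_sketch('mountain man:3 man mountain:1'))

Run directly, this prints [('mountain man', 0.75), ('man mountain', 0.25)], matching the 3:1 normalization asserted in test_legacy_blend.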