From ced9c83e96266b4e0f75fd0376eee9375777b534 Mon Sep 17 00:00:00 2001
From: Damian at mba
Date: Sun, 30 Oct 2022 23:01:05 +0100
Subject: [PATCH] various prompting fixes

---
 ldm/invoke/conditioning.py                    | 53 ++++++++++---------
 ldm/invoke/prompt_parser.py                   | 32 ++++++++---
 .../diffusion/shared_invokeai_diffusion.py    |  5 +-
 scripts/invoke.py                             |  7 ++-
 4 files changed, 59 insertions(+), 38 deletions(-)

diff --git a/ldm/invoke/conditioning.py b/ldm/invoke/conditioning.py
index 9b775823aa..04fbd7c10a 100644
--- a/ldm/invoke/conditioning.py
+++ b/ldm/invoke/conditioning.py
@@ -1,12 +1,9 @@
 '''
-This module handles the generation of the conditioning tensors, including management of
-weighted subprompts.
+This module handles the generation of the conditioning tensors.
 
 Useful function exports:
 
 get_uc_and_c_and_ec()           get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
-split_weighted_subpromopts()    split subprompts, normalize and weight them
-log_tokenization()              print out colour-coded tokens and warn if truncated
 
 '''
 import re
@@ -16,7 +13,7 @@ from typing import Union
 import torch
 
 from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
-    CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment
+    CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
 from ..models.diffusion.cross_attention_control import CrossAttentionControl
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
@@ -58,8 +55,11 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
     if type(parsed_prompt) is Blend:
         blend: Blend = parsed_prompt
         embeddings_to_blend = None
-        for flattened_prompt in blend.prompts:
-            this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
+        for i,flattened_prompt in enumerate(blend.prompts):
+            this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
+                                                                                 flattened_prompt,
+                                                                                 log_tokens=log_tokens,
+                                                                                 log_display_label=f"(blend part {i+1}, weight={blend.weights[i]})" )
             embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
                 (embeddings_to_blend, this_embedding))
         conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
@@ -103,14 +103,20 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
                 edit_options.append(None)
                 original_token_count += count
                 edited_token_count += count
-        original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt, log_tokens=log_tokens)
+        original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
+                                                                                                original_prompt,
+                                                                                                log_tokens=log_tokens,
+                                                                                                log_display_label="(.swap originals)")
         # naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
         # subsequent tokens when there is >1 edit and earlier edits change the total token count.
         # eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
         # 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
         # token 'smiling' in the inactive 'cat' edit.
         # todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
-        edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt, log_tokens=log_tokens)
+        edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
+                                                                                            edited_prompt,
+                                                                                            log_tokens=log_tokens,
+                                                                                            log_display_label="(.swap replacements)")
 
         conditioning = original_embeddings
         edited_conditioning = edited_embeddings
@@ -121,9 +127,15 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
                                                                        edit_options = edit_options
                                                                        )
     else:
-        conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
+        conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
+                                                                           flattened_prompt,
+                                                                           log_tokens=log_tokens,
+                                                                           log_display_label="(prompt)")
 
-    unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
+    unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
+                                                                         parsed_negative_prompt,
+                                                                         log_tokens=log_tokens,
+                                                                         log_display_label="(unconditioning)")
     if isinstance(conditioning, dict):
         # hybrid conditioning is in play
         unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
@@ -144,26 +156,15 @@ def build_token_edit_opcodes(original_tokens, edited_tokens):
     return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
 
 
-def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False):
+def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False, log_display_label: str=None):
     if type(flattened_prompt) is not FlattenedPrompt:
         raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
     fragments = [x.text for x in flattened_prompt.children]
     weights = [x.weight for x in flattened_prompt.children]
     embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
-    if not flattened_prompt.is_empty and log_tokens:
-        start_token = model.cond_stage_model.tokenizer.bos_token_id
-        end_token = model.cond_stage_model.tokenizer.eos_token_id
-        tokens_list = tokens[0].tolist()
-        if tokens_list[0] == start_token:
-            tokens_list[0] = ''
-        try:
-            first_end_token_index = tokens_list.index(end_token)
-            tokens_list[first_end_token_index] = ''
-            tokens_list = tokens_list[:first_end_token_index+1]
-        except ValueError:
-            pass
-
-        print(f">> Prompt fragments {fragments}, tokenized to \n{tokens_list}")
+    if log_tokens:
+        text = " ".join(fragments)
+        log_tokenization(text, model, display_label=log_display_label)
     return embeddings, tokens
 
 
diff --git a/ldm/invoke/prompt_parser.py b/ldm/invoke/prompt_parser.py
index e6856761ac..0b3890597f 100644
--- a/ldm/invoke/prompt_parser.py
+++ b/ldm/invoke/prompt_parser.py
@@ -2,6 +2,19 @@ import string
 from typing import Union, Optional
 import re
 import pyparsing as pp
+'''
+This module parses prompt strings and produces tree-like structures that can be used
+to generate and control the conditioning tensors, including weighted subprompts.
+
+Useful class exports:
+
+PromptParser - parses prompts
+
+Useful function exports:
+
+split_weighted_subprompts()     split subprompts, normalize and weight them
+log_tokenization()              print out colour-coded tokens and warn if truncated
+'''
 
 class Prompt():
     """
@@ -205,12 +218,17 @@ class Blend():
         #print("making Blend with prompts", prompts, "and weights", weights)
         if len(prompts) != len(weights):
             raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
-        for c in prompts:
-            if type(c) is not Prompt and type(c) is not FlattenedPrompt:
-                raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
+        for p in prompts:
+            if type(p) is not Prompt and type(p) is not FlattenedPrompt:
+                raise(PromptParser.ParsingException(f"{type(p)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
+            for f in p.children:
+                if isinstance(f, CrossAttentionControlSubstitute):
+                    raise(PromptParser.ParsingException(f"while parsing Blend: sorry, you cannot do .swap() as part of a Blend"))
+
         # upcast all lists to Prompt objects
         self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
-                        else Prompt(x) for x in prompts]
+                        else Prompt(x)
+                        for x in prompts]
         self.prompts = prompts
         self.weights = weights
         self.normalize_weights = normalize_weights
@@ -662,9 +680,7 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
 # shows how the prompt is tokenized
 # usually tokens have '</w>' to indicate end-of-word,
 # but for readability it has been replaced with ' '
-def log_tokenization(text, model, log=False, weight=1):
-    if not log:
-        return
+def log_tokenization(text, model, display_label=None):
     tokens    = model.cond_stage_model.tokenizer._tokenize(text)
     tokenized = ""
     discarded = ""
@@ -679,7 +695,7 @@ def log_tokenization(text, model, log=False, weight=1):
             usedTokens += 1
         else:  # over max token length
             discarded = discarded + f"\x1b[0;3{s};40m{token}"
-    print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
+    print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
     if discarded != "":
         print(
             f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
         )
diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py
index 90c8ebb981..dd2643cd0a 100644
--- a/ldm/models/diffusion/shared_invokeai_diffusion.py
+++ b/ldm/models/diffusion/shared_invokeai_diffusion.py
@@ -157,10 +157,11 @@ class InvokeAIDiffuserComponent:
             # percent_through will never reach 1.0 (but this is intended)
             return float(step_index) / float(self.cross_attention_control_context.step_count)
 
         # find the best possible index of the current sigma in the sigma sequence
-        sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1]
+        smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
+        sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
         # flip because sigmas[0] is for the fully denoised image
         # percent_through must be <1
-        return 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0])
+        return 1.0 - float(sigma_index + 1) / float(self.model.sigmas.shape[0])
 
         # print('estimated percent_through', percent_through, 'from sigma', sigma.item())
diff --git a/scripts/invoke.py b/scripts/invoke.py
index 62e7b5dea1..8b2519515b 100644
--- a/scripts/invoke.py
+++ b/scripts/invoke.py
@@ -10,6 +10,9 @@ import warnings
 import time
 import traceback
 import yaml
+
+from ldm.invoke.prompt_parser import PromptParser
+
 sys.path.append('.')    # corrects a weird problem on Macs
 from ldm.invoke.readline import get_completer
 from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png
@@ -18,7 +21,7 @@ from ldm.invoke.image_util import make_grid
 from ldm.invoke.log import write_log
 from omegaconf import OmegaConf
 from pathlib import Path
-from pyparsing import ParseException
+import pyparsing
 
 # global used in multiple functions (fix)
 infile = None
@@ -340,7 +343,7 @@ def main_loop(gen, opt):
                         catch_interrupts=catch_ctrl_c,
                         **vars(opt)
                     )
-                except ParseException as e:
+                except (PromptParser.ParsingException, pyparsing.ParseException) as e:
                     print('** An error occurred while processing your prompt **')
                     print(f'** {str(e)} **')
             elif operation == 'postprocess':
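
A small standalone sketch (not part of the patch itself; the sigma values below are made up) of why the shared_invokeai_diffusion.py hunk guards the torch.nonzero() lookup: when no sigma in the schedule satisfies the condition, torch.nonzero() returns an empty tensor, so indexing it with [-1] would raise an IndexError; the patched code falls back to index 0 instead.

    import torch

    sigmas = torch.tensor([14.6, 8.0, 3.5, 1.1])   # hypothetical sigma schedule, stand-in for self.model.sigmas
    sigma = torch.tensor(0.5)                      # smaller than every entry, so sigmas <= sigma matches nothing

    smaller_sigmas = torch.nonzero(sigmas <= sigma)                                # empty tensor of shape (0, 1)
    sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0  # guarded fallback to 0
    percent_through = 1.0 - float(sigma_index + 1) / float(sigmas.shape[0])
    print(sigma_index, percent_through)            # prints: 0 0.75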