various prompting fixes

This commit is contained in:
Damian at mba 2022-10-30 23:01:05 +01:00 committed by Lincoln Stein
parent 80f2cfe3e3
commit ced9c83e96
4 changed files with 59 additions and 38 deletions

View File

@@ -1,12 +1,9 @@
'''
This module handles the generation of the conditioning tensors, including management of
weighted subprompts.
This module handles the generation of the conditioning tensors.
Useful function exports:
get_uc_and_c_and_ec() get the conditioned and unconditioned latent, and edited conditioning if we're doing cross-attention control
split_weighted_subpromopts() split subprompts, normalize and weight them
log_tokenization() print out colour-coded tokens and warn if truncated
'''
import re
@@ -16,7 +13,7 @@ from typing import Union
import torch
from .prompt_parser import PromptParser, Blend, FlattenedPrompt, \
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, CrossAttentionControlAppend, Fragment
CrossAttentionControlledFragment, CrossAttentionControlSubstitute, Fragment, log_tokenization
from ..models.diffusion.cross_attention_control import CrossAttentionControl
from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
@@ -58,8 +55,11 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
if type(parsed_prompt) is Blend:
blend: Blend = parsed_prompt
embeddings_to_blend = None
for flattened_prompt in blend.prompts:
this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
for i,flattened_prompt in enumerate(blend.prompts):
this_embedding, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
flattened_prompt,
log_tokens=log_tokens,
log_display_label=f"(blend part {i+1}, weight={blend.weights[i]})" )
embeddings_to_blend = this_embedding if embeddings_to_blend is None else torch.cat(
(embeddings_to_blend, this_embedding))
conditioning = WeightedFrozenCLIPEmbedder.apply_embedding_weights(embeddings_to_blend.unsqueeze(0),
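A minimal sketch of the weighted-blend idea used in this branch, in plain torch only. blend_embeddings is a hypothetical stand-in for WeightedFrozenCLIPEmbedder.apply_embedding_weights, and the [77, 768] embedding shape is an assumption:

import torch

def blend_embeddings(per_prompt_embeddings: list, weights: list, normalize: bool = True) -> torch.Tensor:
    # stack the per-prompt embeddings accumulated in the loop above: [n_prompts, 77, 768]
    stacked = torch.stack(per_prompt_embeddings)
    w = torch.tensor(weights, dtype=stacked.dtype)
    if normalize:
        w = w / w.sum()
    # weight each prompt's embedding and sum over the prompt dimension
    return (stacked * w.view(-1, 1, 1)).sum(dim=0)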
@@ -103,14 +103,20 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
edit_options.append(None)
original_token_count += count
edited_token_count += count
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, original_prompt, log_tokens=log_tokens)
original_embeddings, original_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
original_prompt,
log_tokens=log_tokens,
log_display_label="(.swap originals)")
# naïvely building a single edited_embeddings like this disregards the effects of changing the absolute location of
# subsequent tokens when there is >1 edit and earlier edits change the total token count.
# eg "a cat.swap(smiling dog, s_start=0.5) eating a hotdog.swap(pizza)" - when the 'pizza' edit is active but the
# 'cat' edit is not, the 'pizza' feature vector will nevertheless be affected by the introduction of the extra
# token 'smiling' in the inactive 'cat' edit.
# todo: build multiple edited_embeddings, one for each edit, and pass just the edited fragments through to the CrossAttentionControl functions
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model, edited_prompt, log_tokens=log_tokens)
edited_embeddings, edited_tokens = build_embeddings_and_tokens_for_flattened_prompt(model,
edited_prompt,
log_tokens=log_tokens,
log_display_label="(.swap replacements)")
conditioning = original_embeddings
edited_conditioning = edited_embeddings
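A plain-Python illustration of the token-offset problem described in the comment above (the token lists are made up, not real tokenizer output):

original = ["a", "cat", "eating", "a", "hotdog"]
edited   = ["a", "smiling", "dog", "eating", "a", "pizza"]
# 'hotdog' sits at index 4 in the original, but its replacement 'pizza' sits at
# index 5 in the edited sequence, because the inactive cat -> "smiling dog" edit
# added a token; a single edited_embeddings therefore shifts the 'pizza' vector
# even when only the second .swap() is active.
print(original.index("hotdog"), edited.index("pizza"))   # 4 5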
@@ -121,9 +127,15 @@ def get_uc_and_c_and_ec(prompt_string_uncleaned, model, log_tokens=False, skip_n
edit_options = edit_options
)
else:
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt, log_tokens=log_tokens)
conditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
flattened_prompt,
log_tokens=log_tokens,
log_display_label="(prompt)")
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model, parsed_negative_prompt, log_tokens=log_tokens)
unconditioning, _ = build_embeddings_and_tokens_for_flattened_prompt(model,
parsed_negative_prompt,
log_tokens=log_tokens,
log_display_label="(unconditioning)")
if isinstance(conditioning, dict):
# hybrid conditioning is in play
unconditioning, conditioning = flatten_hybrid_conditioning(unconditioning, conditioning)
@@ -144,26 +156,15 @@ def build_token_edit_opcodes(original_tokens, edited_tokens):
return SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes()
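For reference, a self-contained example of what SequenceMatcher.get_opcodes() produces for two token-id sequences (the ids are made up for illustration):

from difflib import SequenceMatcher

original_tokens = [49406, 320, 2368, 49407]   # e.g. <start> a cat <end>
edited_tokens   = [49406, 320, 1929, 49407]   # e.g. <start> a dog <end>
print(SequenceMatcher(None, original_tokens, edited_tokens).get_opcodes())
# [('equal', 0, 2, 0, 2), ('replace', 2, 3, 2, 3), ('equal', 3, 4, 3, 4)]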
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False):
def build_embeddings_and_tokens_for_flattened_prompt(model, flattened_prompt: FlattenedPrompt, log_tokens: bool=False, log_display_label: str=None):
if type(flattened_prompt) is not FlattenedPrompt:
raise Exception(f"embeddings can only be made from FlattenedPrompts, got {type(flattened_prompt)} instead")
fragments = [x.text for x in flattened_prompt.children]
weights = [x.weight for x in flattened_prompt.children]
embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
if not flattened_prompt.is_empty and log_tokens:
start_token = model.cond_stage_model.tokenizer.bos_token_id
end_token = model.cond_stage_model.tokenizer.eos_token_id
tokens_list = tokens[0].tolist()
if tokens_list[0] == start_token:
tokens_list[0] = '<start>'
try:
first_end_token_index = tokens_list.index(end_token)
tokens_list[first_end_token_index] = '<end>'
tokens_list = tokens_list[:first_end_token_index+1]
except ValueError:
pass
print(f">> Prompt fragments {fragments}, tokenized to \n{tokens_list}")
if log_tokens:
text = " ".join(fragments)
log_tokenization(text, model, display_label=log_display_label)
return embeddings, tokens
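A small, runnable sketch of how a FlattenedPrompt's children are turned into the parallel fragment/weight lists passed to the embedder; the Fragment dataclass below is a local stand-in, not the project's class:

from dataclasses import dataclass

@dataclass
class Fragment:              # local stand-in, not the class imported above
    text: str
    weight: float = 1.0

children = [Fragment("a photo of a cat", 1.0), Fragment("in the rain", 0.8)]
fragments = [x.text for x in children]
weights = [x.weight for x in children]
# these parallel lists are what the function above hands to
# model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
print(fragments, weights)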

View File

@@ -2,6 +2,19 @@ import string
from typing import Union, Optional
import re
import pyparsing as pp
'''
This module parses prompt strings and produces tree-like structures that can be used to generate and control the conditioning tensors, including management of weighted subprompts.
Useful class exports:
PromptParser - parses prompts
Useful function exports:
split_weighted_subprompts() split subprompts, normalize and weight them
log_tokenization() print out colour-coded tokens and warn if truncated
'''
class Prompt():
"""
@@ -205,12 +218,17 @@ class Blend():
#print("making Blend with prompts", prompts, "and weights", weights)
if len(prompts) != len(weights):
raise PromptParser.ParsingException(f"while parsing Blend: mismatched prompts/weights counts {prompts}, {weights}")
for c in prompts:
if type(c) is not Prompt and type(c) is not FlattenedPrompt:
raise(PromptParser.ParsingException(f"{type(c)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
for p in prompts:
if type(p) is not Prompt and type(p) is not FlattenedPrompt:
raise(PromptParser.ParsingException(f"{type(p)} cannot be added to a Blend, only Prompts or FlattenedPrompts"))
for f in p.children:
if isinstance(f, CrossAttentionControlSubstitute):
raise(PromptParser.ParsingException(f"while parsing Blend: sorry, you cannot do .swap() as part of a Blend"))
# upcast all lists to Prompt objects
self.prompts = [x if (type(x) is Prompt or type(x) is FlattenedPrompt)
else Prompt(x) for x in prompts]
else Prompt(x)
for x in prompts]
self.prompts = prompts
self.weights = weights
self.normalize_weights = normalize_weights
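A hedged sketch of the shape of the new Blend validation: any prompt whose children contain a .swap() fragment is rejected at construction time. The classes below are local stand-ins, not the project's Prompt/CrossAttentionControlSubstitute:

class ParsingException(Exception): pass
class Substitute: pass                      # stands in for CrossAttentionControlSubstitute
class PromptStandIn:
    def __init__(self, children): self.children = children

def validate_blend_prompts(prompts):
    # mirror of the loop added above: reject any prompt containing a .swap()
    for p in prompts:
        for f in p.children:
            if isinstance(f, Substitute):
                raise ParsingException("sorry, you cannot do .swap() as part of a Blend")

try:
    validate_blend_prompts([PromptStandIn(["a cat"]), PromptStandIn([Substitute()])])
except ParsingException as e:
    print(f"blend rejected: {e}")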
@@ -662,9 +680,7 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, log=False, weight=1):
if not log:
return
def log_tokenization(text, model, display_label=None):
tokens = model.cond_stage_model.tokenizer._tokenize(text)
tokenized = ""
discarded = ""
@@ -679,7 +695,7 @@ def log_tokenization(text, model, log=False, weight=1):
usedTokens += 1
else: # over max token length
discarded = discarded + f"\x1b[0;3{s};40m{token}"
print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
print(f"\n>> Tokens {display_label or ''} ({usedTokens}):\n{tokenized}\x1b[0m")
if discarded != "":
print(
f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"

View File

@@ -157,10 +157,11 @@ class InvokeAIDiffuserComponent:
# percent_through will never reach 1.0 (but this is intended)
return float(step_index) / float(self.cross_attention_control_context.step_count)
# find the best possible index of the current sigma in the sigma sequence
sigma_index = torch.nonzero(self.model.sigmas <= sigma)[-1]
smaller_sigmas = torch.nonzero(self.model.sigmas <= sigma)
sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
# flip because sigmas[0] is for the fully denoised image
# percent_through must be <1
return 1.0 - float(sigma_index.item() + 1) / float(self.model.sigmas.shape[0])
return 1.0 - float(sigma_index + 1) / float(self.model.sigmas.shape[0])
# print('estimated percent_through', percent_through, 'from sigma', sigma.item())
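The change above guards against torch.nonzero returning an empty tensor when the current sigma is below every sigma in the schedule. A standalone sketch of the same logic (the sigma schedule is made up):

import torch

# made-up ascending schedule; sigmas[0] corresponds to the fully denoised image
sigmas = torch.tensor([0.3, 0.9, 1.9, 3.5, 5.8, 9.1, 14.6])
sigma = torch.tensor(0.1)    # below every entry, so nonzero() comes back empty

smaller_sigmas = torch.nonzero(sigmas <= sigma)
sigma_index = smaller_sigmas[-1].item() if smaller_sigmas.shape[0] > 0 else 0
percent_through = 1.0 - float(sigma_index + 1) / float(sigmas.shape[0])
print(percent_through)       # ~0.857, instead of an IndexError from indexing [-1] on an empty tensor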

View File

@@ -10,6 +10,9 @@ import warnings
import time
import traceback
import yaml
from ldm.invoke.prompt_parser import PromptParser
sys.path.append('.') # corrects a weird problem on Macs
from ldm.invoke.readline import get_completer
from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png
@@ -18,7 +21,7 @@ from ldm.invoke.image_util import make_grid
from ldm.invoke.log import write_log
from omegaconf import OmegaConf
from pathlib import Path
from pyparsing import ParseException
import pyparsing
# global used in multiple functions (fix)
infile = None
@@ -340,7 +343,7 @@ def main_loop(gen, opt):
catch_interrupts=catch_ctrl_c,
**vars(opt)
)
except ParseException as e:
except (PromptParser.ParsingException, pyparsing.ParseException) as e:
print('** An error occurred while processing your prompt **')
print(f'** {str(e)} **')
elif operation == 'postprocess':
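The except clause above now catches the parser's own ParsingException as well as pyparsing's ParseException. A minimal sketch of that error-handling shape, with generate_images standing in for the (truncated) generation call shown in the hunk:

import pyparsing
from ldm.invoke.prompt_parser import PromptParser

def safe_generate(generate_images, **kwargs):
    # generate_images is a placeholder callable, not a project API
    try:
        return generate_images(**kwargs)
    except (PromptParser.ParsingException, pyparsing.ParseException) as e:
        print('** An error occurred while processing your prompt **')
        print(f'** {str(e)} **')
        return None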