InvokeAI/ldm/invoke/conditioning.py
2022-10-19 21:06:42 +02:00

139 lines
5.4 KiB
Python

'''
This module handles the generation of the conditioning tensors, including management of
weighted subprompts.
Useful function exports:
get_uc_and_c_and_ec() get the unconditioned, conditioned and edited conditionings plus token edit opcodes
split_weighted_subprompts() split subprompts, normalize and weight them
log_tokenization() print out colour-coded tokens and warn if truncated
'''
import re
from difflib import SequenceMatcher
import torch
def get_uc_and_c_and_ec(prompt, model, log_tokens=False, skip_normalize=False):
    """
    Build the conditioning tensors for a prompt.

    Text inside [square brackets] is cut out of the prompt and used as the
    unconditioned text; text inside {curly braces} is cut out and used as the
    edited text. Whatever remains is split into weighted subprompts.

    Returns a 4-tuple (uc, c, ec, edit_opcodes):
      uc            conditioning for the bracketed words ('' when there are none)
      c             conditioning for the prompt, or the weighted sum of the
                    per-subprompt conditionings when there are several subprompts
      ec            conditioning for the edited words, or None when there are
                    none or when the prompt is a blend
      edit_opcodes  build_token_edit_opcodes() result for prompt vs. edit
                    tokens, or None under the same conditions as ec

    model.get_learned_conditioning() is expected to return a
    (conditioning, tokens) pair, as every call here unpacks two values.
    """
    # [bracketed] spans become the unconditioned text; with none, this is ''.
    negative_spans, prompt = _extract_delimited(prompt, r'\[(.*?)\]')
    unconditioned_words = ' '.join(negative_spans)

    # {braced} spans become the edited text; None signals "no edit requested".
    edit_spans, prompt = _extract_delimited(prompt, r'\{(.*?)\}')
    edited_words = ' '.join(edit_spans) if edit_spans else None

    subprompts = split_weighted_subprompts(prompt, skip_normalize)

    ec = None
    edit_opcodes = None
    uc, _ = model.get_learned_conditioning([unconditioned_words])

    if len(subprompts) <= 1:
        # Plain single prompt: the only path on which edits are supported.
        log_tokenization(prompt, model, log_tokens, 1)
        c, c_tokens = model.get_learned_conditioning([prompt])
        if edited_words is not None:
            ec, ec_tokens = model.get_learned_conditioning([edited_words])
            edit_opcodes = build_token_edit_opcodes(c_tokens, ec_tokens)
        return (uc, c, ec, edit_opcodes)

    # Blend: accumulate weight * embedding for each subprompt, starting from
    # a zero tensor shaped like uc.
    c = torch.zeros_like(uc)
    for fragment, fragment_weight in subprompts:
        log_tokenization(fragment, model, log_tokens, fragment_weight)
        fragment_embedding, _ = model.get_learned_conditioning([fragment])
        c = torch.add(c, fragment_embedding, alpha=fragment_weight)
    if edited_words is not None:
        print("can't do cross-attention control with blends just yet, ignoring edits")
    return (uc, c, ec, edit_opcodes)


def _extract_delimited(prompt, pattern):
    """
    Return (captured_spans, prompt) for a delimiter regex. When at least one
    span matches, the matches are replaced by a space and runs of spaces are
    collapsed; otherwise the prompt is returned unchanged.
    """
    spans = re.findall(pattern, prompt)
    if spans:
        prompt = re.sub(' +', ' ', re.sub(pattern, ' ', prompt))
    return spans, prompt
def build_token_edit_opcodes(c_tokens, ec_tokens):
    """
    Diff two token sequences and return difflib opcodes describing how to
    turn the first into the second.

    Each argument is moved to the CPU, converted to a numpy array, and its
    element at index 0 is used as the sequence to compare (presumably the
    single batch entry -- confirm against the caller).
    """
    original_ids, edited_ids = (
        batch.cpu().numpy()[0] for batch in (c_tokens, ec_tokens)
    )
    matcher = SequenceMatcher(None, original_ids, edited_ids)
    return matcher.get_opcodes()
def split_weighted_subprompts(text, skip_normalize=False) -> list:
    """
    Split a prompt into a list of (subprompt, weight) tuples.

    Grabs all text up to the first occurrence of ':' as a subprompt and takes
    the number following ':' as its weight (1.0 when no number is given),
    repeating until no text remains. A colon escaped as '\\:' is kept as part
    of the subprompt and unescaped to ':' in the result.

    With skip_normalize the raw weights are returned. Otherwise the weights
    are divided by their sum; if they sum to zero a warning is printed and
    every subprompt gets the same weight. An empty prompt yields [].
    """
    # Raw string: the pattern is full of regex escapes (\d, \s, \., \\) that
    # are invalid *string* escapes and trigger a SyntaxWarning otherwise.
    prompt_parser = re.compile(r"""
        (?P<prompt>             # capture group for 'prompt'
            (?:\\:|[^:])+       # match one or more non ':' characters or escaped colons '\:'
        )                       # end 'prompt'
        (?:                     # non-capture group
            :+                  # match one or more ':' characters
            (?P<weight>         # capture group for 'weight'
                -?\d+(?:\.\d+)? # match positive or negative integer or decimal number
            )?                  # end weight capture group, make optional
            \s*                 # strip spaces after weight
        |                       # OR
            $                   # else, if no ':' then match end of line
        )                       # end non-capture group
    """, re.VERBOSE)
    parsed_prompts = [
        (match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
        for match in prompt_parser.finditer(text)
    ]
    # Nothing to normalize; also avoids a misleading zero-sum warning for an
    # empty prompt.
    if skip_normalize or not parsed_prompts:
        return parsed_prompts
    weight_sum = sum(weight for _, weight in parsed_prompts)
    if weight_sum == 0:
        print(
            "Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
        equal_weight = 1 / len(parsed_prompts)
        return [(subprompt, equal_weight) for subprompt, _ in parsed_prompts]
    return [(subprompt, weight / weight_sum) for subprompt, weight in parsed_prompts]
def log_tokenization(text, model, log=False, weight=1):
    """
    Print how the prompt is tokenized, one colour per token.

    Tokens usually carry '</w>' to mark end-of-word; for readability it is
    replaced with a space. Tokens at index >= model.cond_stage_model.max_length
    are reported separately as discarded. Does nothing unless log is true.

    text    prompt (or subprompt) to tokenize
    model   object exposing cond_stage_model.tokenizer._tokenize() and
            cond_stage_model.max_length
    log     print only when true
    weight  subprompt weight, shown with two decimals in the header line
    """
    if not log:
        return
    tokens = model.cond_stage_model.tokenizer._tokenize(text)
    max_length = model.cond_stage_model.max_length
    kept = []
    dropped = []
    for index, raw_token in enumerate(tokens):
        token = raw_token.replace('</w>', ' ')
        # Cycle through ANSI foreground colours 1..6 by token position so that
        # discarded tokens alternate colours too (a used-token counter stops
        # advancing once the limit is reached).
        colour = (index % 6) + 1
        piece = f"\x1b[0;3{colour};40m{token}"
        if index < max_length:
            kept.append(piece)
        else:  # over max token length
            dropped.append(piece)
    tokenized = "".join(kept)
    print(f"\n>> Tokens ({len(kept)}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
    if dropped:
        discarded = "".join(dropped)
        print(
            f">> Tokens Discarded ({len(dropped)}):\n{discarded}\x1b[0m"
        )