Mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
97 lines · 3.8 KiB · Python
|
'''
|
||
|
This module handles the generation of the conditioning tensors, including management of
|
||
|
weighted subprompts.
|
||
|
|
||
|
Useful function exports:
|
||
|
|
||
|
get_uc_and_c() get the unconditioned and conditioned latents (returned in that order)
|
||
|
split_weighted_subprompts() split subprompts, normalize and weight them
|
||
|
log_tokenization() print out colour-coded tokens and warn if truncated
|
||
|
|
||
|
'''
|
||
|
import re
|
||
|
import torch
|
||
|
|
||
|
def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
    """
    Build the unconditioned and conditioned embeddings for ``prompt``.

    The unconditioned embedding is the model's conditioning for the empty
    prompt ''. If ``prompt`` splits into more than one weighted sub-prompt,
    the conditioned embedding is the weighted sum of each sub-prompt's
    conditioning. Otherwise the prompt is encoded verbatim.

    :param prompt: the user prompt, possibly containing ``text:weight`` parts.
    :param model: object exposing ``get_learned_conditioning(list_of_str)``.
    :param log_tokens: forwarded to ``log_tokenization`` to print the tokens.
    :param skip_normalize: forwarded to ``split_weighted_subprompts``; when
        True the sub-prompt weights are used as written instead of summing to 1.
    :return: the tuple ``(uc, c)``.
    """
    unconditioned = model.get_learned_conditioning([''])

    subprompts = split_weighted_subprompts(prompt, skip_normalize)

    # Zero or one sub-prompt: encode the original prompt text unchanged.
    if len(subprompts) <= 1:
        log_tokenization(prompt, model, log_tokens)
        return (unconditioned, model.get_learned_conditioning([prompt]))

    # Several sub-prompts: accumulate a weighted sum starting from zeros
    # shaped like the unconditioned embedding.
    # NOTE(review): this assumes every get_learned_conditioning result has
    # the same shape as `unconditioned` — confirm for the model in use.
    conditioned = torch.zeros_like(unconditioned)
    for text, weight in subprompts:
        log_tokenization(text, model, log_tokens)
        embedding = model.get_learned_conditioning([text])
        conditioned = torch.add(conditioned, embedding, alpha=weight)
    return (unconditioned, conditioned)
|
||
|
|
||
|
def split_weighted_subprompts(text, skip_normalize=False) -> list:
    """
    Split a prompt into weighted sub-prompts.

    The text is consumed left to right. Everything up to the next unescaped
    ':' becomes a sub-prompt, and the number right after the ':' (optionally
    negative, optionally decimal) becomes its weight. A sub-prompt with no
    number after its ':' (or with no ':' at all) gets weight 1.0. A colon
    preceded by a backslash does not split the text; it is kept in the
    sub-prompt and unescaped to a plain ':' in the result. Spaces following
    a weight are discarded.

    :param text: the raw prompt string.
    :param skip_normalize: if True, return the weights as written; otherwise
        scale them so they sum to 1 (falling back to equal weights, with a
        printed warning, when they sum to zero).
    :return: list of ``(subprompt, weight)`` tuples; empty if ``text`` is empty.
    """
    # Raw string: the pattern needs literal backslashes for '\\:', '\d', '\.'
    # and '\s'. In a normal string these are invalid escape sequences
    # (a SyntaxWarning on Python 3.12+).
    prompt_parser = re.compile(r"""
        (?P<prompt>         # capture group for 'prompt'
        (?:\\:|[^:])+       # match one or more non ':' characters or escaped colons '\:'
        )                   # end 'prompt'
        (?:                 # non-capture group
        :+                  # match one or more ':' characters
        (?P<weight>         # capture group for 'weight'
        -?\d+(?:\.\d+)?     # match positive or negative integer or decimal number
        )?                  # end weight capture group, make optional
        \s*                 # strip spaces after weight
        |                   # OR
        $                   # else, if no ':' then match end of line
        )                   # end non-capture group
        """, re.VERBOSE)
    parsed_prompts = [
        (match.group("prompt").replace("\\:", ":"), float(match.group("weight") or 1))
        for match in prompt_parser.finditer(text)
    ]
    # An empty prompt yields no sub-prompts. Return early: the zero-sum
    # fallback below would otherwise divide by len(parsed_prompts) == 0.
    if not parsed_prompts or skip_normalize:
        return parsed_prompts
    weight_sum = sum(weight for _, weight in parsed_prompts)
    if weight_sum == 0:
        print(
            "Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
        equal_weight = 1 / len(parsed_prompts)
        return [(subprompt, equal_weight) for subprompt, _ in parsed_prompts]
    return [(subprompt, weight / weight_sum) for subprompt, weight in parsed_prompts]
|
||
|
|
||
|
def log_tokenization(text, model, log=False):
    """
    Print how ``text`` is tokenized, colour-coding tokens with ANSI escapes.

    Tokens usually end with '</w>' to mark end-of-word; for readability that
    marker is replaced with a space. Tokens at positions at or beyond
    ``model.cond_stage_model.max_length`` are printed in a separate
    "Tokens Discarded" section.

    :param text: the prompt (or sub-prompt) to tokenize.
    :param model: object exposing ``cond_stage_model.tokenizer._tokenize(str)``
        (returning a list of token strings) and ``cond_stage_model.max_length``.
    :param log: when False, return immediately without touching ``model``.
    """
    if not log:
        return
    tokens = model.cond_stage_model.tokenizer._tokenize(text)
    # NOTE(review): if the text encoder adds start/end tokens around the
    # prompt, fewer than max_length prompt tokens may actually be used —
    # confirm against the tokenizer/encoder configuration.
    max_length = model.cond_stage_model.max_length
    used_parts = []
    discarded_parts = []
    for position, raw_token in enumerate(tokens):
        token = raw_token.replace('</w>', ' ')
        # Cycle through ANSI foreground colours 31-36 by token position, so
        # adjacent tokens are distinguishable in the discarded section too
        # (a counter of *used* tokens stops advancing once the limit is hit).
        colour = (position % 6) + 1
        target = used_parts if position < max_length else discarded_parts
        target.append(f"\x1b[0;3{colour};40m{token}")
    tokenized = "".join(used_parts)
    print(f"\n>> Tokens ({len(used_parts)}):\n{tokenized}\x1b[0m")
    if discarded_parts:
        discarded = "".join(discarded_parts)
        print(
            f">> Tokens Discarded ({len(discarded_parts)}):\n{discarded}\x1b[0m"
        )
|