Fix token display when using -t

Add true weight used for subprompt
This commit is contained in:
Bernard Maltais 2022-09-18 16:10:26 -04:00 committed by Lincoln Stein
parent f0b500fba8
commit 2743e17588

View File

@@ -38,14 +38,14 @@ def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
         c = torch.zeros_like(uc)
         # normalize each "sub prompt" and add it
         for subprompt, weight in weighted_subprompts:
-            log_tokenization(subprompt, model, log_tokens)
+            log_tokenization(subprompt, model, log_tokens, weight)
             c = torch.add(
                 c,
                 model.get_learned_conditioning([subprompt]),
                 alpha=weight,
             )
     else:  # just standard 1 prompt
-        log_tokenization(prompt, model, log_tokens)
+        log_tokenization(prompt, model, log_tokens, 1)
         c = model.get_learned_conditioning([prompt])
         uc = model.get_learned_conditioning([unconditioned_words])
     return (uc, c)
@@ -86,7 +86,7 @@ def split_weighted_subprompts(text, skip_normalize=False)->list:
 # shows how the prompt is tokenized
 # usually tokens have '</w>' to indicate end-of-word,
 # but for readability it has been replaced with ' '
-def log_tokenization(text, model, log=False):
+def log_tokenization(text, model, log=False, weight=1):
     if not log:
         return
     tokens = model.cond_stage_model.tokenizer._tokenize(text)
@@ -103,8 +103,8 @@ def log_tokenization(text, model, log=False):
             usedTokens += 1
         else:  # over max token length
             discarded = discarded + f"\x1b[0;3{s};40m{token}"
-    print(f"\n>> Tokens ({usedTokens}):\n{tokenized}\x1b[0m")
+    print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m")
     if discarded != "":
         print(
             f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
         )