tokenization logging (take 2)

This adds an optional -t argument that prints out a color-coded view of the prompt's tokenization. SD has a maximum of 77 tokens; if your prompt is longer than that, the tokens over the limit are silently discarded.
Using -t lets you see exactly how your prompt is being tokenized, which helps with prompt crafting.
xra 2022-08-29 12:28:49 +09:00
parent 9a8cd9684e
commit fef632e0e1
2 changed files with 34 additions and 0 deletions
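As background, here is a minimal sketch of the tokenization being logged, using the Hugging Face transformers CLIP tokenizer for illustration (that package and the checkpoint name are assumptions for this sketch, not part of this commit):

    # Illustration only: Stable Diffusion's text encoder is CLIP, whose
    # tokenizer has a 77-token context window (including the special
    # begin/end-of-text markers), so anything past the limit is cut off.
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
    print(tokenizer.tokenize('a painting of a fox'))
    # e.g. ['a</w>', 'painting</w>', 'of</w>', 'a</w>', 'fox</w>']
    print(tokenizer.model_max_length)  # 77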


@@ -213,6 +213,7 @@ class T2I:
        upscale=None,
        variants=None,
        sampler_name=None,
        log_tokenization=False,
        **args,
    ):  # eat up additional cruft
        """
@@ -253,6 +254,7 @@ class T2I:
        batch_size = batch_size or self.batch_size
        iterations = iterations or self.iterations
        strength = strength or self.strength
        self.log_tokenization = log_tokenization
        model = (
            self.load_model()
@@ -489,6 +491,7 @@ class T2I:
                weight = weights[i]
                if not skip_normalize:
                    weight = weight / totalWeight
                self._log_tokenization(subprompts[i])
                c = torch.add(
                    c,
                    self.model.get_learned_conditioning(
@@ -497,6 +500,7 @@ class T2I:
                    alpha=weight,
                )
        else:  # just standard 1 prompt
            self._log_tokenization(prompt)
            c = self.model.get_learned_conditioning(batch_size * [prompt])
        return (uc, c)
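In the weighted branch above, each subprompt's conditioning is folded into c via torch.add(..., alpha=weight), i.e. c = sum(weight_i * c_i). A standalone sketch of that accumulation pattern (the tensor shapes are illustrative assumptions, not the real model API):

    import torch

    # Accumulate weighted embeddings: c = c + w * e, matching the
    # torch.add(c, ..., alpha=weight) call in the diff. For SD v1, CLIP
    # produces a 77x768 embedding per prompt (shape assumed here).
    c = torch.zeros(1, 77, 768)
    embeddings = [torch.randn(1, 77, 768), torch.randn(1, 77, 768)]
    weights = [0.7, 0.3]  # already normalized, as in the skip_normalize=False path
    for e, w in zip(embeddings, weights):
        c = torch.add(c, e, alpha=w)  # equivalent to c = c + w * e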
@@ -657,3 +661,27 @@ class T2I:
            weights.append(1.0)
            remaining = 0
        return prompts, weights

    # shows how the prompt is tokenized
    # usually tokens have '</w>' to indicate end-of-word,
    # but for readability it has been replaced with ' '
    def _log_tokenization(self, text):
        if not self.log_tokenization:
            return
        tokens = self.model.cond_stage_model.tokenizer._tokenize(text)
        tokenized = ""
        discarded = ""
        usedTokens = 0
        totalTokens = len(tokens)
        for i in range(0, totalTokens):
            token = tokens[i].replace('</w>', ' ')
            # alternate color
            s = (usedTokens % 6) + 1
            if i < self.model.cond_stage_model.max_length:
                tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
                usedTokens += 1
            else:  # over max token length
                discarded = discarded + f"\x1b[0;3{s};40m{token}"
        print(f"\nTokens ({usedTokens}):\n{tokenized}\x1b[0m")
        if discarded != "":
            print(f"Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m")


@@ -462,6 +462,12 @@ def create_cmd_parser():
        metavar='SAMPLER_NAME',
        help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
    )
    parser.add_argument(
        '-t',
        '--log_tokenization',
        action='store_true',
        help='shows how the prompt is split into tokens'
    )
    return parser
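With the flag registered, tokenization logging can be requested per prompt. A hypothetical session at the interactive prompt (colors not reproduced here; exact token boundaries depend on the CLIP vocabulary):

    dream> "a painting of a fox" -t
    Tokens (5):
    a painting of a fox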