Mirror of https://github.com/invoke-ai/InvokeAI
Merge pull request #176 from xraxra/show-tokenization
Print tokenization data during image generation so that truncated prompts are visible.
Commit: 7fe7cdc8c9
@@ -215,6 +215,7 @@ class T2I:
         upscale=None,
         variants=None,
         sampler_name=None,
+        log_tokenization=False,
         **args,
     ):  # eat up additional cruft
         """
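
The trailing **args in this signature (the "eat up additional cruft" comment) is what makes additions like log_tokenization backward-compatible: unknown keyword arguments are absorbed rather than raising TypeError. A toy sketch of the pattern (the stub name is hypothetical, not code from the repository):

    # Toy illustration of the '**args eats cruft' pattern; the name
    # 'prompt2image_stub' is hypothetical, not from the repository.
    def prompt2image_stub(prompt, log_tokenization=False, **args):
        if args:
            # callers can pass options this version does not recognize
            print(f"ignoring unrecognized options: {sorted(args)}")
        return prompt, log_tokenization

    prompt2image_stub('a fox', log_tokenization=True, some_future_flag=1)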
@@ -260,6 +261,7 @@ class T2I:
         batch_size = batch_size or self.batch_size
         iterations = iterations or self.iterations
         strength = strength or self.strength
+        self.log_tokenization = log_tokenization

         model = (
             self.load_model()
@@ -503,6 +505,7 @@ class T2I:
                 weight = weights[i]
                 if not skip_normalize:
                     weight = weight / totalWeight
+                self._log_tokenization(subprompts[i])
                 c = torch.add(
                     c,
                     self.model.get_learned_conditioning(
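
For context, the loop this hunk touches accumulates one conditioning tensor per weighted subprompt. A minimal sketch of that blend, with random tensors standing in for get_learned_conditioning() outputs (shapes and variable values here are illustrative):

    # Sketch of the weighted blend performed above; the tensors are
    # stand-ins for get_learned_conditioning() outputs, not repo code.
    import torch

    embeddings = [torch.randn(1, 77, 768) for _ in range(2)]  # one per subprompt
    weights = [1.0, 3.0]
    totalWeight = sum(weights)
    c = torch.zeros_like(embeddings[0])
    for emb, weight in zip(embeddings, weights):
        # same torch.add(..., alpha=...) form as the diff: c += (w / total) * emb
        c = torch.add(c, emb, alpha=weight / totalWeight)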
@@ -511,6 +514,7 @@ class T2I:
                         alpha=weight,
                     )
         else:  # just standard 1 prompt
+            self._log_tokenization(prompt)
             c = self.model.get_learned_conditioning(batch_size * [prompt])
         return (uc, c)

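
The single-prompt branch gets the same logging because truncation is silent either way: CLIP-style tokenizers drop everything past a fixed context length. A standalone sketch using Hugging Face's CLIPTokenizer as a stand-in (an assumption; the repo wraps its own tokenizer inside cond_stage_model):

    # Standalone sketch of silent CLIP truncation; CLIPTokenizer here is
    # a stand-in for the tokenizer wrapped by self.model.cond_stage_model.
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
    prompt = 'a highly detailed painting of ' + 'an ornate castle ' * 20
    tokens = tokenizer.tokenize(prompt)     # BPE tokens, '</w>'-suffixed
    limit = tokenizer.model_max_length - 2  # 77 minus BOS/EOS slots
    print(f'kept {min(len(tokens), limit)} of {len(tokens)} tokens')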
@@ -674,3 +678,27 @@ class T2I:
                 weights.append(1.0)
                 remaining = 0
         return prompts, weights
+
+    # shows how the prompt is tokenized
+    # usually tokens have '</w>' to indicate end-of-word,
+    # but for readability it has been replaced with ' '
+    def _log_tokenization(self, text):
+        if not self.log_tokenization:
+            return
+        tokens = self.model.cond_stage_model.tokenizer._tokenize(text)
+        tokenized = ""
+        discarded = ""
+        usedTokens = 0
+        totalTokens = len(tokens)
+        for i in range(0, totalTokens):
+            token = tokens[i].replace('</w>', ' ')
+            # alternate color
+            s = (usedTokens % 6) + 1
+            if i < self.model.cond_stage_model.max_length:
+                tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
+                usedTokens += 1
+            else:  # over max token length
+                discarded = discarded + f"\x1b[0;3{s};40m{token}"
+        print(f"\nTokens ({usedTokens}):\n{tokenized}\x1b[0m")
+        if discarded != "":
+            print(f"Tokens Discarded ({totalTokens - usedTokens}):\n{discarded}\x1b[0m")
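
The escape sequences above are plain ANSI SGR codes: \x1b[0;3{s};40m selects foreground color 30+s (s cycling 1 through 6, red through cyan) on a black (40) background, and \x1b[0m resets the terminal. A minimal standalone demo of the same scheme:

    # Minimal demo of the ANSI color cycling used by _log_tokenization().
    words = ['a', 'painting', 'of', 'a', 'fox']
    colored = ''.join(
        f'\x1b[0;3{(i % 6) + 1};40m{w} ' for i, w in enumerate(words)
    )
    print(colored + '\x1b[0m')  # reset so later output is uncolored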
@@ -478,6 +478,12 @@ def create_cmd_parser():
         metavar='SAMPLER_NAME',
         help=f'Switch to a different sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
     )
+    parser.add_argument(
+        '-t',
+        '--log_tokenization',
+        action='store_true',
+        help='shows how the prompt is split into tokens'
+    )
     return parser


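
A hedged sketch of how the new flag reaches prompt2image(); the t2i object and the call site are illustrative, since the repo's REPL script does this wiring itself:

    # Illustrative wiring only: 't2i' stands in for however the caller
    # constructs the T2I object; create_cmd_parser() is from the diff above.
    parser = create_cmd_parser()
    opt = parser.parse_args(['-t'])  # same as '--log_tokenization'
    results = t2i.prompt2image(
        prompt='a painting of a fox',
        log_tokenization=opt.log_tokenization,  # stored as self.log_tokenization
    )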