Add --log_tokenization to sysargs

This allows the --log_tokenization option to be used as a command line argument (or from invokeai.init), making it possible to view tokenization information in the terminal when using the web interface.
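For example (the launcher script path is an assumption, not part of this commit), tokenization logging can now be enabled when starting the web interface, or persisted as a line in invokeai.init:

    # hypothetical invocation; assumes the standard invoke.py launcher
    python scripts/invoke.py --web --log_tokenization

    # or add this single line to invokeai.init:
    --log_tokenization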
Author: whosawhatsis
Date: 2023-02-04 19:56:20 -08:00
Parent: 0ca499cf96
Commit: 75b62d6ca8
2 changed files with 11 additions and 3 deletions

ldm/invoke/args.py

@@ -196,6 +196,7 @@ class Args(object):
         elif os.path.exists(legacyinit):
             print(f'>> WARNING: Old initialization file found at {legacyinit}. This location is deprecated. Please move it to {Globals.root}/invokeai.init.')
             sysargs.insert(0,f'@{legacyinit}')
+        Globals.log_tokenization = self._arg_parser.parse_args(sysargs).log_tokenization
         self._arg_switches = self._arg_parser.parse_args(sysargs)
         return self._arg_switches
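A minimal sketch of the pattern in the hunk above, assuming the parser was built with fromfile_prefix_chars='@' (which is what the f'@{legacyinit}' insertion relies on); Globals here is a local stand-in for ldm.invoke.globals:

    import argparse
    from types import SimpleNamespace

    Globals = SimpleNamespace(log_tokenization=False)  # stand-in for ldm.invoke.globals

    parser = argparse.ArgumentParser(fromfile_prefix_chars='@')
    parser.add_argument('--log_tokenization', '-t', action='store_true')

    def parse_args(sysargs):
        # Mirror the parsed flag onto the process-wide namespace so code that
        # never sees the parsed switches (e.g. the web request path) can read it.
        switches = parser.parse_args(sysargs)
        Globals.log_tokenization = switches.log_tokenization
        return switches

The committed hunk calls parse_args twice; since parsing the same argument list twice produces the same namespace, reusing a single result as sketched is behaviorally equivalent.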
@@ -599,6 +600,12 @@ class Args(object):
             help=f'Set the default sampler. Supported samplers: {", ".join(SAMPLER_CHOICES)}',
             default='k_lms',
         )
+        render_group.add_argument(
+            '--log_tokenization',
+            '-t',
+            action='store_true',
+            help='shows how the prompt is split into tokens'
+        )
         render_group.add_argument(
             '-f',
             '--strength',
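The new switch can be exercised in isolation to confirm its store_true semantics (plain argparse; the render_group wrapper from the hunk above is omitted):

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument('--log_tokenization', '-t', action='store_true',
                   help='shows how the prompt is split into tokens')

    assert p.parse_args([]).log_tokenization is False                     # off by default
    assert p.parse_args(['-t']).log_tokenization is True                  # short form
    assert p.parse_args(['--log_tokenization']).log_tokenization is True  # long form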

ldm/invoke/conditioning.py

@@ -17,6 +17,7 @@ from ..models.diffusion import cross_attention_control
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
 from ..modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
+from ldm.invoke.globals import Globals

 def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
@@ -92,7 +93,7 @@ def _get_conditioning_for_prompt(parsed_prompt: Union[Blend, FlattenedPrompt], p
     Process prompt structure and tokens, and return (conditioning, unconditioning, extra_conditioning_info)
     """
-    if log_tokens:
+    if log_tokens or Globals.log_tokenization:
         print(f">> Parsed prompt to {parsed_prompt}")
         print(f">> Parsed negative prompt to {parsed_negative_prompt}")
@@ -235,7 +236,7 @@ def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedProm
     fragments = [x.text for x in flattened_prompt.children]
     weights = [x.weight for x in flattened_prompt.children]
     embeddings, tokens = model.get_learned_conditioning([fragments], return_tokens=True, fragment_weights=[weights])
-    if log_tokens:
+    if log_tokens or Globals.log_tokenization:
         text = " ".join(fragments)
         log_tokenization(text, model, display_label=log_display_label)
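To see why the or Globals.log_tokenization fallback matters, consider a hypothetical call site like the one the web handler uses; the names below come from the diff, but the exact invocation is an assumption:

    # The web path calls the conditioning code without passing log_tokens,
    # so log_tokens stays at its default of False...
    uc, c, ec = get_uc_and_c_and_ec(prompt_string, model=model)
    # ...but when --log_tokenization was given on the command line (or in
    # invokeai.init), Globals.log_tokenization is True and both checks above
    # fire, printing the token breakdown to the terminal.

This is what makes tokenization information visible in the terminal while generating from the web interface, as the commit message describes.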