Bypass the 77 token limit (#2896)

This ought to be working, but I don't know how it's supposed to behave, so I haven't been able to verify it. At the very least, I know the conditioning tensors are being pushed all the way through to the SD UNet; I just haven't been able to verify that what comes out is what's expected. Please test.

You'll need to `pip install -e .` after switching to the branch, because it currently pulls from a non-main `compel` branch. Once it's verified to work as intended, I'll promote that `compel` branch to PyPI.
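
For anyone testing: here is, roughly, the mechanism this branch relies on, as a minimal sketch that drives `compel` directly rather than through InvokeAI (the CLIP checkpoint name and the prompts are placeholders for illustration only, not what InvokeAI itself loads):

```python
from compel import Compel
from transformers import CLIPTextModel, CLIPTokenizer

# Any CLIP text encoder works for illustration; SD 1.x models use this one.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# truncate_long_prompts=False is the key switch: prompts longer than the
# tokenizer's 77-token window are encoded in chunks and concatenated instead
# of being silently cut off at 77 tokens.
compel = Compel(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    truncate_long_prompts=False,
)

long_positive = "a highly detailed oil painting of " + "a red fox in the snow, " * 30
short_negative = "blurry, low quality"

c = compel.build_conditioning_tensor(long_positive)
uc = compel.build_conditioning_tensor(short_negative)

# The two tensors can now have different sequence lengths, so they are padded
# to match before being handed to the SD UNet.
[c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])
print(c.shape, uc.shape)  # same sequence length for both (a multiple of 77)
```

If the padding step works, both conditioning tensors should reach the UNet with identical shapes even when only one of the prompts blows past 77 tokens.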
blessedcoolant 2023-03-09 23:52:28 +13:00 committed by GitHub
commit 76d5fa4694
2 changed files with 13 additions and 11 deletions


@@ -17,7 +17,7 @@ from compel.prompt_parser import (
     Fragment,
     PromptParser,
 )
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTokenizer

 from invokeai.backend.globals import Globals
@@ -71,6 +71,7 @@ def get_uc_and_c_and_ec(
         text_encoder=text_encoder,
         textual_inversion_manager=model.textual_inversion_manager,
         dtype_for_device_getter=torch_dtype,
+        truncate_long_prompts=False
     )

     # get rid of any newline characters
@@ -82,12 +83,12 @@ def get_uc_and_c_and_ec(
     legacy_blend = try_parse_legacy_blend(
         positive_prompt_string, skip_normalize_legacy_blend
     )
-    positive_prompt: FlattenedPrompt | Blend
+    positive_prompt: Union[FlattenedPrompt, Blend]
     if legacy_blend is not None:
         positive_prompt = legacy_blend
     else:
         positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
-    negative_prompt: FlattenedPrompt | Blend = Compel.parse_prompt_string(
+    negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
         negative_prompt_string
     )
@@ -96,6 +97,7 @@ def get_uc_and_c_and_ec(
     c, options = compel.build_conditioning_tensor_for_prompt_object(positive_prompt)
     uc, _ = compel.build_conditioning_tensor_for_prompt_object(negative_prompt)
+    [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])

     tokens_count = get_max_token_count(tokenizer, positive_prompt)
@@ -116,12 +118,12 @@ def get_prompt_structure(
     legacy_blend = try_parse_legacy_blend(
         positive_prompt_string, skip_normalize_legacy_blend
     )
-    positive_prompt: FlattenedPrompt | Blend
+    positive_prompt: Union[FlattenedPrompt, Blend]
     if legacy_blend is not None:
         positive_prompt = legacy_blend
     else:
         positive_prompt = Compel.parse_prompt_string(positive_prompt_string)
-    negative_prompt: FlattenedPrompt | Blend = Compel.parse_prompt_string(
+    negative_prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(
         negative_prompt_string
     )
@@ -129,7 +131,7 @@ def get_prompt_structure(
 def get_max_token_count(
-    tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=True
+    tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
 ) -> int:
     if type(prompt) is Blend:
         blend: Blend = prompt
@@ -245,7 +247,7 @@ def log_tokenization_for_prompt_object(
            )

-def log_tokenization_for_text(text, tokenizer, display_label=None):
+def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
     """shows how the prompt is tokenized
     # usually tokens have '</w>' to indicate end-of-word,
     # but for readability it has been replaced with ' '
@@ -260,11 +262,11 @@ def log_tokenization_for_text(text, tokenizer, display_label=None):
         token = tokens[i].replace("</w>", " ")
         # alternate color
         s = (usedTokens % 6) + 1
-        if i < tokenizer.model_max_length:
+        if truncate_if_too_long and i >= tokenizer.model_max_length:
+            discarded = discarded + f"\x1b[0;3{s};40m{token}"
+        else:
             tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
             usedTokens += 1
-        else:  # over max token length
-            discarded = discarded + f"\x1b[0;3{s};40m{token}"

     if usedTokens > 0:
         print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
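
A quick way to check whether a test prompt actually crosses the old 77-token boundary is to count tokens with the raw tokenizer (a rough check only, with an illustrative checkpoint name; `get_max_token_count` above additionally handles blends):

```python
from transformers import CLIPTokenizer

# Same tokenizer family the SD 1.x text encoder uses.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

prompt = "a highly detailed oil painting of " + "a red fox in the snow, " * 30
tokens = tokenizer.tokenize(prompt)

# model_max_length is 77 for CLIP; with this branch, tokens past that boundary
# are no longer discarded by default when logging or building conditioning.
print(f"{len(tokens)} prompt tokens vs model_max_length={tokenizer.model_max_length}")
```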


@@ -38,7 +38,7 @@ dependencies = [
   "albumentations",
   "click",
   "clip_anytorch", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
-  "compel==0.1.7",
+  "compel==0.1.10",
   "datasets",
   "diffusers[torch]~=0.14",
   "dnspython==2.2.1",